【AAAI 2024】解鎖深度表格學習(Deep Tabular Learning)的關鍵:算術特徵互動

阿里云大数据AI技术發表於2024-03-15

近日,阿里雲人工智慧平臺PAI與浙江大學吳健、應豪超老師團隊合作論文《Arithmetic Feature Interaction is Necessary for Deep Tabular Learning》正式在國際人工智慧頂會AAAI-2024上發表。本項工作聚焦於深度表格學習中的一個核心問題:在處理結構化表格資料(tabular data)時,深度模型是否擁有有效的歸納偏差(inductive bias)。我們提出算術特徵互動(arithmetic feature interaction)對深度表格學習是至關重要的假設,並透過建立合成資料集以及設計實現一種支援上述互動的AMFormer架構(一種修改的Transformer架構)來驗證這一假設。實驗結果表明,AMFormer在合成資料集表現出顯著更優的細粒度表格資料建模、訓練樣本效率和泛化能力,並在真實資料的對比上超過一眾基準方法,成為深度表格學習新的SOTA(state-of-the-art)模型。

1、背景

【AAAI 2024】解鎖深度表格學習(Deep Tabular Learning)的關鍵:算術特徵互動

圖1:結構化表格資料示例,引用自[Borisov et al.]

結構化表格資料——這些資料往往以表(Table)的形式儲存於資料庫或數倉中——作為一種在金融、市場營銷、醫學科學和推薦系統等多個領域廣泛使用的重要資料格式,其分析一直是機器學習研究的熱點。表格資料(圖1)通常同時包含數值型(numerical)特徵和類目型(categorical)特徵,並往往伴隨有特徵缺失、噪聲、類別不平衡(class imblanance)等資料質量問題,且缺少時序性、區域性性等有效的先驗歸納偏差,極大地帶來了分析上的挑戰。傳統的樹整合模型(如,XGBoost、LightGBM、CatBoost)因在處理資料質量問題上的魯棒性,依然是工業界實際建模的主流選擇,但其效果很大程度依賴於特徵工程產出的原始特徵質量。

隨著深度學習的流行,研究者試圖引入深度學習端到端建模,從而減少在處理表格資料時對特徵工程的依賴。相關的研究工作至少可以可以分成四大類:(1)在傳統建模方法中疊加深度學習模組(通常是多層感知機MLP),如Wide&Deep、DeepFMs;(2)形狀函式(shape function)採用深度學習建模的廣義加性模型(generalized additive model),如 NAM、NBM、SIAN;(3)樹結構啟發的深度模型,如NODE、Net-DNF;(4)基於Transformer架構的模型,如AutoInt、DCAP、FT-Transformer。儘管如此,深度學習在表格資料上相比樹模型的提升並不顯著且持續,其有效性仍然存在疑問,表格資料因此被視為深度學習尚未征服的最後堡壘。

2、算術特徵互動在深度表格學習的“必要性”

我們認為現有的深度表格學習方法效果不盡如人意的關鍵癥結在於沒有找到有效的建模歸納偏差,並進一步提出算術特徵互動對深度表格學習是至關重要的假設。本節介紹我們透過建立一個合成資料集,並對比引入算數特徵互動前後的模型效果,來驗證該假設。

【AAAI 2024】解鎖深度表格學習(Deep Tabular Learning)的關鍵:算術特徵互動

圖2:合成資料集上的結果對比。圖中+x%表示AMFormer相比Transformer的相對提升。

在上述資料中,我們將引入了算數特徵互動的AMFormer架構與經典的XGBoost和Transformer架構對比。實驗結果顯示:

以上結果共同證實了算術特徵互動在深度表格學習中的顯著意義。

3、演算法架構

【AAAI 2024】解鎖深度表格學習(Deep Tabular Learning)的關鍵:算術特徵互動

圖3:AMFormer架構,其中L表示模型層數。

本節介紹AMFormer架構(圖3),並重點介紹算數特徵互動的引入。AMFormer架構借鑑了經典的Transformer框架,並引入了Arithmetic Block來增強模型的算術特徵互動能力。在AMFormer中,我們首先將原始特徵轉換為具有代表性的嵌入向量,對於數值特徵,我們使用一個1輸入d輸出的線性層;對於類別特徵,則使用一個d維的嵌入查詢表。之後,這些初始嵌入透過L個順序層進行處理,這些層增強了嵌入向量中的上下文和互動元素。每一層中的算術模組採用了並行的加法和乘法注意力機制,以刻意促進算術特徵之間的互動。為了促進梯度流動和增強特徵表示,我們保留了殘差連線和前饋網路。最終,依據這些豐富的嵌入向量,AMFormer使用分類或迴歸頭部生成最終輸出。

算術模組的關鍵元件包括並行注意力機制和提示標記。為了補償需要算術特徵互動的特徵,我們在AMFormer中配置了並行注意力機制,這些機制負責提取有意義的加法和乘法互動候選者。這些互動候選隨著會沿著候選維度被串聯(concatenate)起來,並透過一個下采樣的線性層進行融合,使得AMFormer的每一層都能有效捕捉算術特徵互動,即特徵上的四則演算法運算。為了防止由特徵冗餘引起的過擬合併提升模型在超大規模特徵資料集上的伸縮,我們放棄了原始Transformer架構中平方複雜度的自注意力機制,而是使用兩組提示向量(prompt token vectors)作為加法和乘法查詢。這種方法為AMFormer提供了有限的特徵互動自由度,並且作為一個附帶效果,最佳化了記憶體佔用和訓練效率。

以上是AMFormer在架構層引入的主要創新,關於模型更詳細的實現細節可以參考原文以及我們的開源實現。


4、進一步實驗結果

【AAAI 2024】解鎖深度表格學習(Deep Tabular Learning)的關鍵:算術特徵互動

表1:真實資料集統計以及評估指標。

為了進一步展示AMFormer的效果,我們挑選了四個真實資料集進行實驗。被挑選資料集覆蓋了二分類、多分類以及迴歸任務,資料集統計如表1所示。

【AAAI 2024】解鎖深度表格學習(Deep Tabular Learning)的關鍵:算術特徵互動

表2:AMFormer以及基準方法的效能對比,其中括號內的數字表示該方法在當前資料集上表現的排名,最優以及次優的結果分別以加粗以及下劃線突出。

我們一共測試了包含傳統樹模型(XGBoost)、樹架構深度學習方法(NODE)、高階特徵互動(DCN-V2、DCAP)以及Transformer派生架構(AutoInt、FT-Trans)在內的六個基準演算法以及兩個AMFormer實現(分別選擇AutoInt、FT-Trans做基礎架構,即AMF-A和AMF-F),結果彙總在表2中。

在一系列對比實驗中,AMFormer表現更突出。結果顯示,基於MLP的深度學習方法如DCN-V2在表格資料上的效能不盡如人意,而基於Transformer架構的模型顯示出更大的潛力,但未能始終超過樹模型XGBoost。我們的AMFormer在四個不同的資料集上,與所有六個基準模型相比,表現一致更優:在分類任務中,它將AutoInt和FT-transformer的準確率或AUC提升至少0.5%,最高達到1.23%(EP)和4.96%(CO);在迴歸任務中,它也顯著減少了平均平方誤差。相比其它深度表格學習方法,AMFormer具有更好的魯棒和穩定性,這使得在效能排序中AMFormer斷層式優於其它基準演算法,這些實驗結果充分證明了AMFormer在深度表格學習中的必要性和優越性。


5、結論

本工作研究了深度模型在表格資料上的有效歸納偏置。我們提出,算術特徵互動對於表格深度學習是必要的,並將這一理念融入Transformer架構中,建立了AMFormer。我們在合成資料和真實世界資料上驗證了AMFormer的有效性。合成資料的結果展示了其在精細表格資料建模、訓練資料效率以及泛化方面的優越能力。此外,對真實世界資料的廣泛實驗進一步確認了其一致的有效性。因此,我們相信AMFormer為深度表格學習設定了強有力的歸納偏置。

  • 進一步閱讀:



來自 “ ITPUB部落格 ” ,連結:https://blog.itpub.net/70004426/viewspace-3009120/,如需轉載,請註明出處,否則將追究法律責任。

相關文章