As a reader --> TabDDPM: Modelling Tabular Data with Diffusion Models

阿洛萌萌哒發表於2024-04-23
  • 📌論文分類3:
    TabDDPM——一個擴散模型,它可以普遍應用於任何表格資料集並處理任何特徵型別。https://github.com/yandex-research/tab-ddpm
    • 論文名稱 TabDDPM: Modelling Tabular Data with Diffusion Models
    • 作者 Akim Kotelnikov, Dmitry Baranchuk, Ivan Rubachev, Artem Babenko
    • 期刊名稱 International Conference on Machine Learning. PMLR, 2023: 17564-17579.
    • 簡要摘要
      去噪擴散機率模型正在成為許多重要資料模態的主要生成建模正規化。作為計算機視覺社群中最流行的模型,擴散模型最近在其他領域得到了一些關注,包括語音、自然語言處理和類圖資料。這項工作研究擴散模型的框架是否有利於一般的表格問題,其中資料點通常由異構特徵的向量表示。表格資料固有的異質性使得精確建模非常具有挑戰性,因為單個特徵可能具有完全不同的性質,即有些特徵可能是連續的,有些特徵可能是離散的。
      為了處理這樣的資料型別,本文引入了TabDDPM——一個擴散模型,它可以普遍應用於任何表格資料集並處理任何特徵型別。在廣泛的基準上對TabDDPM進行了廣泛的評估,並證明其優於現有的GAN/VAE替代品,這與擴散模型在其他領域的優勢是一致的。
    • ✏️論文內容
      • 【內容1】
        • 💡Introduction
        • 去噪擴散機率模型(DDPM) 最近在生成建模社群中成為一個非常有研究興趣的物件,因為它們在個體樣本的真實性和多樣性方面往往優於其他方法。DDPM最令人印象深刻的成功是在自然影像領域,擴散模型的優勢在各種應用被成功利用,比如著色、圖片修復、分割、超解析度、語義編輯,等等。除了計算機視覺之外,DDPM框架還在其他領域進行了研究,例如NLP、波形訊號處理、分子圖、時間序列,……,這證明了擴散模型在廣泛問題中的普遍性。
        • 本工作旨在研究DDPM的普遍性是否可以擴充套件到一般表格問題的情況,這些問題在各種工業應用中無處不在,包括由一組異構特徵描述的資料。對於許多此類應用,由於現代隱私法規(如GDPR)禁止釋出真實使用者資料,而生成模型生成的合成資料可以共享,因此對高質量生成模型的需求尤為迫切。然而,由於單個特徵的異質性和典型表格資料集的相對較小的規模,訓練高質量的表格資料模型可能比計算機視覺或NLP更具挑戰性。
        • 本文表明,儘管存在這兩種複雜性,但擴散模型可以成功地近似表格資料的典型分佈,從而在大多數基準測試中獲得最先進的效能。更詳細地說,這項工作的主要貢獻如下:
          • 1. 介紹TabDDPM——一個簡單的用於表格問題的DDPM設計,它可以應用於任何表格任務,並處理混合資料型別,包括數值和分類特徵。
          • 2. 證明了TabDDPM優於為表格資料設計的替代方法,包括基於GAN和基於VAE的方法,並在幾個資料集上說明了這種優勢的來源。
          • 3. 觀察到基於淺插值的方法,例如SMOTE (Chawla等人,2002),產生了令人驚訝的有效合成資料,提供了具有競爭力的高機器學習效率。結果表明,與SMOTE相比,當使用合成資料代替無法共享的真實使用者資料時,TabDDPM的資料更適合涉及隱私的場景。
      • 【內容2】
        • 💡Related Work
          • Diffusion models
            一種生成建模的範例,旨在透過馬爾可夫鏈的端點近似目標分佈,它從給定的引數分佈開始,通常是標準高斯分佈。每個馬爾可夫步驟都是由一個深度神經網路執行的,該網路有效地學習用已知的高斯核反轉擴散過程。Ho等人證明了擴散模型和分數匹配的等價性,表明它們是透過迭代去噪過程將簡單已知分佈逐漸轉換為目標分佈的兩種不同視角。近期的幾項工作開發了更強大的模型架構以及不同的高階學習協議,這導致DDPM在計算機視覺領域的生成質量和多樣性方面優於GAN。這項工作證明了人們也可以成功地將擴散模型用於表格問題。
          • Generative models for tabular problems
            目前是機器學習社群的一個活躍的研究方向,因為許多表格任務對高質量的合成資料有很大的需求。首先,表格資料集通常在大小上是有限的,不像在視覺或NLP問題中,在網際網路上有大量的“額外”資料。其次,適當的合成資料集不包含實際的使用者資料。
            因此,它們不受類似GDPR的監管,可以在不違反匿名性的情況下公開共享。最近的工作已經開發了大量的模型,包括表格VAEs,和基於GAN的方法。透過對大量公共基準進行廣泛的評估,TabDDPM優於現有的替代方案,而且通常有很大的優勢。
          • “Shallow” synthetics generation
            與非結構化影像或自然文字不同,表格資料通常是結構化的,即單個特徵通常是可解釋的,並且不清楚它們的建模是否需要幾層“深度”架構。因此,簡單的插值技術,如SMOTE (Chawla等人,2002)(最初是為了解決類不平衡而提出的)可以作為簡單而強大的解決方案,如(Camino等人,2020)所示,SMOTE在小類過取樣方面優於表格GAN。在本文實驗中,從隱私保護的角度證明了TabDDPM合成資料比用插值技術生產的合成資料的優勢。
      • 【內容3】
        • 💡Background
          • Diffusion models
          • Gaussian diffusion models
          • Multinomial diffusion models
      • 【內容4】
        • 💡TabDDPM

          描述TabDDPM的設計以及影響模型有效性的主要超引數。
          • TabDDPM採用多項擴散法對分類和二值特徵進行建模,採用高斯擴散法對數值特徵進行建模。更詳細地說,對於表格資料x:

          • 對於預處理,使用scikit-learn庫中的高斯分位數變換,每個分類特徵由一個單獨的前向擴散過程處理,即所有特徵的噪聲分量是獨立取樣的。TabDDPM中的反向擴散步驟是由一個多層神經網路建模的,該神經網路的輸出維度與x0相同,其中前N_num個元素是高斯擴散的ε的預測,其餘的是多項式擴散的x_cati^ohe的預測。
          • 分類問題的TabDDPM模型如圖1所示。模型是透過最小化高斯擴散項的均方誤差總和,和每個多項式擴散項的KL散度訓練的。多項擴散的總損失另外除以分類特徵的數目。

          • 對於分類資料集,使用類條件模型,也就是說,pθ(xt−1|xt, y)是習得的;對於迴歸資料集,將目標值作為附加的數值特徵,並學習聯合分佈。
          • 為了對反向過程建模,使用了一個簡單的MLP架構,改編自(gorishny等人,2021):

          • 如(Nichol, 2021; Dhariwal & Nichol, 2021)所述,表格輸入x_in,時間步長t和類標籤y的處理如下:

          • 其中SinTimeEmb指正弦時間嵌入,如(Nichol, 2021;Dhariwal & Nichol, 2021)所述,維度為128。方程5中的所有線性層都有一個固定的投影維度128。
          • TabDDPM中的超引數是必不可少的,因為在實驗中觀察到它們對模型有效性有很強的影響。表1列出了主要的超引數以及每個超引數的搜尋空間,建議使用這些超引數。實驗部分詳細描述了微調過程。


      • 【內容5】
        • 💡Experiments
        • Datasets

        • Baselines
          • TVAE (Xu et al ., 2019)——用於表格資料生成的最先進的變分自動編碼器。據我們所知,目前還沒有一種替代的類似於VAE的模型能夠超越TVAE並且擁有開原始碼。
          • CTGAN(Xu et al., 2019)——可以說是最流行和最知名的基於GAN的合成資料生成模型。
          • CTABGAN(Zhao et al., 2021)——最近一種基於GAN的模型,在各種基準測試中表現優於現有的表格式GAN。這種方法不能處理迴歸任務。
          • CTABGAN+(Zhao et al., 2022)——CTABGAN模型的擴充套件,發表在最近的預印本中。我們不知道還是否有CTABGAN+之後提出的基於GAN的表格資料模型,並有一個公開的原始碼。
          • SMOTE(Chawla et al., 2002)——一種基於“淺”插值的方法,它“生成”一個合成點,作為真實資料點和資料集中第k個最近鄰居的凸組合。該方法最初是針對小類過取樣提出的。這裡將其推廣到合成資料生成,作為簡單的完整性檢查,即,透過插入來自同一類的兩個樣本來“生成”新的合成樣本。對於迴歸問題,透過目標變數的中位數將資料分成兩類。
        • Evaluation measure
          • 主要評估指標是機器學習(ML)的效率(或效用)。更詳細地說,機器學習效率量化了在合成資料上訓練並在真實測試集上評估的分類或迴歸模型的效能。直觀地說,在高質量合成材料上訓練的模型應該比在真實資料上訓練的模型更有競爭力(甚至更好)。本文使用兩種評估協議來計算機器學習效率。
          • 在第一種方案中,計算了一組不同ML模型(邏輯迴歸、決策樹等)的平均效率。在第二個方案中,僅使用CatBoost模型評估機器學習效率,該模型可以說是領先的GBDT實現,在表格任務上提供最先進的效能。【第5.2節的實驗中表明,使用第二種協議是至關重要的,而第一種協議往往會產生誤導。】
          • 為了調整TabDDPM和基線的超引數,使用Optuna庫。調優過程由在保留驗證資料集上生成的合成資料的ML效率值指導(分數在五個不同的取樣種子上平均)。表1報告了TabDDPM所有超引數的搜尋空間。此外,證明使用CatBoost指南調優超引數不會引入任何型別的“CatBoost偏置”,而Catboost-微調的TabDDPM生產的合成資料也優於其他模型,如MLP。
        • 1.Qualitative comparison
          • 定性研究TabDDPM與TVAE、CTABGAN+基線相比,對個體和聯合特徵分佈的建模能力。特別是,對於每個資料集,從TabDDPM、 TVAE和CTABGAN+中生成與特定資料集中的真實訓練集相同大小的合成資料集。對於分類資料集,每個類別根據其在真實資料集中的比例進行取樣。在圖2中視覺化了真實資料和合成資料的典型單個特徵分佈。為了完整起見,給出了不同型別和分佈的特徵。

          • 在大多數情況下,與TVAE和CTABGAN+相比,TabDDPM產生的特徵分佈更真實。對於(1)均勻分佈的數值特徵,(2)具有高基數的分類特徵,以及(3)結合連續和離散分佈的混合型別特徵,優勢更加明顯。
          • 此外,還視覺化了對不同資料集的真實資料和合成資料計算的關聯矩陣之間的差異,參見圖3。

          • 為了計算相關矩陣,使用皮爾遜相關係數來表示數值相關性,使用相關比率來表示分類數值情況,使用Theil’s U統計量來表示分類特徵。與CTABGAN+和TVAE相比,TabDDPM生成的合成資料集具有更現實的兩兩相關性。這些例項表明,TabDDPM模型比其他模型更靈活,併產生更好的合成資料。還遵循(Zhao et al ., 2021)並測量數值特徵之間的Wasserstein距離和分類特徵之間的Jensen-Shannon散度,報告了相關矩陣之間的L2距離。結果在表3中顯示為所有資料集的平均排名(越低越好)。排名越低,WD、JS散度和L2距離越低。

        • 2.Machine Learning efficiency
          • 將TabDDPM與其他生成模型在機器學習效率方面進行比較。從每個生成模型中,按表1的比例取樣一個具有真實訓練集大小的合成資料集。然後使用這些合成資料來訓練分類/迴歸模型,然後使用真實的測試集對其進行評估。實驗中,分類效能用F1分數來評價,迴歸效能用R2分數來評價。使用兩種方案:
            • 1.計算一組不同ML模型的平均ML效率,該集合包括決策樹、隨機森林、邏輯迴歸(或Ridge迴歸)和來自scikit-learn庫的MLP模型。
            • 2.根據當前最先進的表格資料模型計算機器學習效率。具體來說,考慮了CatBoost和(gorishny等人,2021)的MLP架構進行評估。CatBoost和MLP超引數使用來自(gorishny等人,2021)的搜尋空間在每個資料集上進行徹底調優。這種評估協議更可靠地展示了合成資料的實用價值,因為在大多數實際場景中,從業者對使用弱和次優分類器/迴歸器不感興趣。
          • 兩種方案計算的ML效率值如表4、5所示。為了計算每個值,對合成生成的五個隨機種子的結果進行平均;對於每個生成的資料集,對訓練分類器/迴歸器的十個隨機種子進行平均。


            • 在這兩種評估方案中,TabDDPM在大多數資料集上都明顯優於TVAE和CTABGAN+,這突出了表格資料的擴散模型的優勢,並在先前的工作中證明了其他領域。
            • 基於插值的SMOTE方法表現出與TabDDPM相競爭的效能,並且通常顯著優於GAN/VAE方法。
            • 有趣的是,大多數關於表格資料生成模型的先前工作都沒有與SMOTE進行比較,而SMOTE似乎是一個簡單的基線,這是具有挑戰性的。
            • 雖然許多先前的工作使用第一種評估方案來計算機器學習效率,但本文認為第二種(使用最先進的模型)更合適。表4、5顯示,第一種方案的分類/迴歸效能的絕對值要低得多,即在考慮的基準測試中,弱分類器/迴歸器實質上不如CatBoost。因此,人們很難使用這些次優模型來代替CatBoost,並且它們的效能值對從業者來說是沒有資訊的。此外,在第一種方案中,對合成資料的訓練往往比對真實資料的訓練更有利。這給人一種印象,即生成模型產生的資料比真實資料更有價值。然而,在大多數實際場景中,當使用調優的ML模型時,情況並非如此。
          • 總的來說,TabDDPM提供了最先進的生成效能,可以用作高質量合成資料的來源。有趣的是,就機器學習效率而言,一個簡單的“淺”SMOTE方法與TabDDPM競爭,這就提出了一個問題,即是否需要複雜的深度生成模型。下面對這個問題給出一個肯定的答案。
        • 3. Privacy
          • 研究TabDDPM在涉及隱私的設定中,例如,在不洩露個人或敏感資訊的情況下共享資料。在這些設定中,人們對不顯示原始資料集記錄的高質量合成資料感興趣。
          • 用與最近記錄的平均距離來衡量生成資料的隱私性。具體來說,對於每個合成樣本,得到到真實記錄的最小L2距離。平均DCR在所有生成的樣本上取這些距離的平均值。低DCR值表明合成樣本基本上模擬了一些真實的資料點,並且可能違反隱私要求。較高的DCR值表示生成模型可以生成“新”記錄,而不僅僅是真實資料的近副本。請注意,分佈外資料,例如隨機噪聲,也將提供高DCR。因此,DCR需要與ML效率一起考慮。
          • 表7給出了TabDDPM、SMOTE、CTABGAN+和TVAE的DCR值。觀察到TabDDPM比SMOTE更私密,比GAN/VAE替代品更不私密,將此歸因於基於GAN/VAE基線的ML效用顯著降低。

          • 由於SMOTE計算的是真實記錄的凸組合,原始的DCR度量可能會降低SMOTE的隱私性。為了解決這個問題,使用真實資料在每個資料集上預訓練一個MLP模型。然後,使用該模型從合成資料中提取特徵,並在預訓練模型的潛在空間中測量DCR。表14給出了MLP特徵的平均DCR值。結果與表7基本一致,並沒有改變前面結論。此外,本文還視覺化了圖4中最小合成距離的直方圖。對於SMOTE,大多數距離值都集中在零附近,而TabDDPM樣本離實際資料點明顯更遠。

          • 下面衡量一個完整黑箱隱私攻擊的成功率(見表6)。

          • 攻擊的目的是推斷一條記錄是否屬於其原始訓練資料。結果表明:TabDDPM比SMOTE更能抵抗這種完整的黑盒攻擊。所有這些實驗都證實,TabDDPM在涉及隱私的場景中顯著優於SMOTE,並且仍然提供最先進的機器學習效率。
      • 【內容6】
        • 💡Limitations and discussion
          • 本文所提出的方法並沒有假裝是一個提供高隱私和高ML實用性的一體化解決方案。實驗表明,TabDDPM比“淺”SMOTE更隱私,但TabDDPM的資料是否能滿足現實世界中涉及隱私的應用,沒有給出明確的答案。因此,DDPM生成的資料的隱私問題需要進一步研究。此外,本文中使用的DCR並不是一種最終的隱私措施,也沒有涵蓋一些關鍵的用例。例如,記錄之間的L2距離沒有考慮單個特徵的重要性,如果某些敏感特徵重合,則無法檢測洩漏。
          • 此外,在本文的工作中,使用多項擴散來處理分類特徵。然而,也存在其他方法,例如(Chen et al ., 2022; Campbell et al, 2022; Zheng & Charoenphakdee, 2022)。這些技術中的每一種都適用於TabDDPM,並且可能是一個有趣的研究方向。對於數值特徵,TabDDPM的可能擴充套件可以從(Nazabal et al, 2020)中得到啟發,該特徵區分了不同型別的數值變數,即實值、正實值或序數。
    • 總結
      • 本文探討了擴散建模框架在表格資料領域的應用前景。特別地,描述了可以處理由數值特徵和分類特徵組成的混合資料型別的DDPM設計。對於大多數考慮的基準,與基於GAN/VAE的競爭對手相比,TabDDPM生成的合成資料始終具有更高的質量。有趣的是,像SMOTE這樣的淺插值技術已經證明了有競爭力的ML實用程式,需要被視為簡單而有效的基線。然而,在必須確保資料隱私的設定中,TabDDPM優於SMOTE。
    • 附錄
      • A. MLP evaluation and tuning
      • B. Additional results
      • C. Additional visualizations
      • D. Distance to Closest Record using pretrained MLP features
      • E. Hyperparameters Search Spaces
      • F. Datasets
      • G. Environment and Runtime

相關文章