破解聯邦學習中的辛普森悖論,浙大提出反事實學習新框架FedCFA

机器之心發表於2025-01-13

圖片

AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

江中華,浙江大學軟體學院碩士生二年級,導師為張聖宇老師。研究方向為大小模型端雲協同計算。張聖宇,浙江大學平臺「百人計劃」研究員。研究方向包括大小模型端雲協同計算,多媒體分析與資料探勘。

隨著機器學習技術的發展,隱私保護和分散式最佳化的需求日益增長。聯邦學習作為一種分散式機器學習技術,允許多個客戶端在不共享資料的情況下協同訓練模型,從而有效地保護了使用者隱私。然而,每個客戶端的資料可能各不相同,有的資料量大,有的資料量小;有的資料特徵豐富,有的資料特徵單一。這種資料的異質性和不平衡性(Non-IID)會導致一個問題:本地訓練的客戶模型忽視了全域性資料中明顯的更廣泛的模式,聚合的全域性模型可能無法準確反映所有客戶端的資料分佈,甚至可能出現「辛普森悖論」—— 多端各自資料分佈趨勢相近,但與多端全域性資料分佈趨勢相悖。

為了解決這一問題,來自浙江大學人工智慧研究所的研究團隊提出了 FedCFA,一個基於反事實學習的新型聯邦學習框架。

FedCFA 引入了端側反事實學習機制,透過在客戶端本地生成與全域性平均資料對齊的反事實樣本,緩解端側資料中存在的偏見,從而有效避免模型學習到錯誤的特徵 - 標籤關聯。該研究已被 AAAI 2025 接收。

圖片
  • 論文標題:FedCFA: Alleviating Simpson’s Paradox in Model Aggregation with Counterfactual Federated Learning
  • 論文連結:https://arxiv.org/abs/2412.18904
  • 專案地址:https://github.com/hua-zi/FedCFA

辛普森悖論

辛普森悖論(Simpson's Paradox)是一種統計現象。簡單來說,當你把資料分成幾個子組時,某些趨勢或關係在每個子組中表現出一致的方向,但在整個資料集中卻出現了相反的趨勢。
圖片
圖 1:辛普森悖論。在全域性資料集上觀察到的趨勢在子集上消失 / 逆轉,聚合的全域性模型無法準確反映全域性資料分佈

在聯邦學習中,辛普森悖論可能會導致全域性模型無法準確捕捉到資料的真實分佈。例如,某些客戶端的資料中存在特定的特徵 - 標籤關聯(如顏色與動物種類的關係),而這些關聯可能在全域性資料中並不存在。因此,直接將本地模型匯聚成全域性模型可能會引入錯誤的學習結果,影響模型的準確性。

如圖 2 所示。考慮一個用於對貓和狗影像進行分類的聯邦學習系統,涉及具有不同資料集的兩個客戶端。客戶端 i 的資料集主要包括白貓和黑狗的影像,客戶端 j 的資料集包括淺灰色貓和棕色狗的影像。對於每個客戶端而言,資料集揭示了類似的趨勢:淺色動物被歸類為「貓」,而深色動物被歸類為「狗」。這導致聚合的全域性模型傾向於將顏色與類別標籤相關聯併為顏色特徵分配更高的權重。然而,全域性資料分佈引入了許多不同顏色的貓和狗的影像(例如黑貓和白狗),與聚合的全域性模型相矛盾。在全域性資料上訓練的模型可以很容易地發現動物顏色與特定分類無關,從而減少顏色特徵的權重。
圖片
圖 2:FedCFA 可以生成客戶端本地不存在的反事實樣本,防止模型學習到不正確的特徵 - 標籤關聯。

反事實學習

反事實(Counterfactual)就像是「如果事情發生了另一種情況,結果會如何?」 的假設性推理。在機器學習中,反事實學習透過生成與現實資料不同的虛擬樣本,來探索不同條件下的模型行為。這些虛擬樣本可以幫助模型更好地理解資料中的因果關係,避免學習到虛假的關聯。

反事實學習的核心思想是透過對現有資料進行干預,生成新的樣本,這些樣本反映了某種假設條件下的情況。例如,在影像分類任務中,我們可以改變影像中的某些特徵(如顏色、形狀等),生成與原圖不同的反事實樣本。透過讓模型學習這些反事實樣本,可以提高模型對真實資料分佈的理解,避免過擬合區域性資料的特點。

反事實學習廣泛應用於推薦系統、醫療診斷、金融風險評估等領域。在聯邦學習中,反事實學習可以幫助緩解辛普森悖論帶來的問題,使全域性模型更準確地反映整體資料的真實分佈。

FedCFA 框架簡介

為了解決聯邦學習中的辛普森悖論問題,FedCFA 框架透過在客戶端生成與全域性平均資料對齊的反事實樣本,使得本地資料分佈更接近全域性分佈,從而有效避免了錯誤的特徵 - 標籤關聯。

如圖 2 所示,透過反事實變換生成的反事實樣本使區域性模型能夠準確掌握特徵 - 標籤關聯,避免區域性資料分佈與全域性資料分佈相矛盾,從而緩解模型聚合中的辛普森悖論。從技術上講,FedCFA 的反事實模組,選擇性地替換關鍵特徵,將全域性平均資料整合到本地資料中,並構建用於模型學習的反事實正 / 負樣本。具體來說,給定本地資料,FedCFA 識別可有可無 / 不可或缺的特徵因子,透過相應地替換這些特徵來執行反事實轉換以獲得正 / 負樣本。透過對更接近全域性資料分佈的反事實樣本進行對比學習,客戶端本地模型可以有效地學習全域性資料分佈。然而,反事實轉換面臨著從資料中提取獨立可控特徵的挑戰。一個特徵可以包含多種型別的資訊,例如動物影像的一個畫素可以攜帶顏色和形狀資訊。為了提高反事實樣本的質量,需要確保提取的特徵因子只包含單一資訊。因此,FedCFA 引入因子去相關損失,直接懲罰因子之間的相關係數,以實現特徵之間的解耦。
圖片
全域性平均資料集的構建

為了構建全域性平均資料集,FedCFA 利用了中心極限定理(Central Limit Theorem, CLT)。根據中心極限定理,若從原資料集中隨機抽取的大小為 n 的子集平均值記為圖片,則當 n 足夠大時,圖片的分佈趨於正態分佈,其均值為 μ,方差圖片,即:圖片,其中 µ 和圖片是原始資料集的期望和方差。

當 n 較小時,圖片能更精細地捕捉資料集的區域性特徵與變化,特別是在保留資料分佈尾部和異常值附近的細節方面表現突出。相反,隨著 n 的增大,圖片的穩定性顯著提升,其方差明顯減小,從而使其作為總體均值 𝜇 的估計更為穩健可靠,對異常值的敏感度大幅降低。此外,在聯邦學習等分散式計算場景中,為了實現通訊成本的有效控制,選擇較大的 n 作為樣本量被視為一種最佳化策略。

基於上述分析,FedCFA 按照以下步驟構建一個大小為 B 的全域性平均資料集,以此近似全域性資料分佈:

1.本地平均資料集計算:每個客戶端將其本地資料集隨機劃分為 B 個大小為圖片的子集圖片,其中圖片為客戶端資料集大小。對於每個子集,計算其平均值圖片。由此,客戶端能夠生成本地平均資料集圖片以近似客戶端原始資料的分佈。

2.全域性平均資料集計算:伺服器端則負責聚合來自多個客戶端的本地平均資料,並採用相同的方法計算出一個大小為 B 的全域性平均資料集圖片,該資料集近似了全域性資料的分佈。對於標籤 Y,FedCFA 採取相同的計算策略,生成其對應的全域性平均資料標籤圖片。最終得到完整的全域性平均資料集圖片

反事實變換模組
圖片
圖 3:FedCFA 中的本地模型訓練流程

FedCFA 中的本地模型訓練流程如圖 3 所示。反事實變換模組的主要任務是在端側生成與全域性資料分佈對齊的反事實樣本:

1. 特徵提取:使用編碼器(Encoder)從原始資料中提取特徵因子圖片
2. 選擇關鍵特徵:計算每個特徵在解碼器(Decoder)輸出層的梯度,選擇梯度小 / 大的 topk 個特徵因子作為可替換的因子,使用圖片將選定的小 / 大梯度因子設定為零,以保留需要的因子
3. 生成反事實樣本:用 Encoder 提取的全域性平均資料特徵替換可替換的特徵因子,得到反事實正 / 負樣本,對於正樣本,標籤不會改變。對於負樣本,使用加權平均值來生成反事實標籤:
圖片
因子去相關損失

同一畫素可能包含多個資料特徵。例如,在動物影像中,一個畫素可以同時攜帶顏色和外觀資訊。為了提高反事實樣本的質量,FedCFA 引入了因子去相關(Factor Decorrelation, FDC)損失,用於減少提取出的特徵因子之間的相關性,確保每個特徵因子只攜帶單一資訊。具體來說,FDC 損失透過計算每對特徵之間的皮爾遜相關係數(Pearson Correlation Coefficient)來衡量特徵的相關性,並將其作為正則化項加入到總損失函式中。

給定一批資料,用圖片來表示第 i 個樣本的所有因子。圖片表示第 i 個樣本的第 j 個因子。將同一批次中每個樣本的相同指標 j 的因子視為一組變數圖片。最後,使用每對變數的 Pearson 相關係數絕對值的平均值作為 FDC 損失:
圖片
其中 Cov (・) 是協方差計算函式,Var (・) 是方差計算函式。最終的總損失為:
圖片
實驗結果

實驗採用兩個指標:500 輪後的全域性模型精度 和 達到目標精度所需的通訊輪數,來評估 FedCFA 的效能。
圖片
圖片
圖片
實驗基於 MNIST 構建了一個具有辛普森悖論的資料集。具體來說,給 1 和 7 兩類影像進行上色,並按顏色深淺劃分給 5 個客戶端。每個客戶端的資料中,數字 1 的顏色都比數字 7 的顏色深。隨後預訓練一個準確率 96% 的 MLP 模型,作為聯邦學習模型初始模型。讓 FedCFA 與 FedAvg,FedMix 兩個 baseline 作為對比,在該資料集上進行訓練。如圖 5 所示,訓練過程中,FedAvg 和 FedMix 均受辛普森悖論的影響,全域性模型準確率下降。而 FedCFA 透過反事實轉換,可以破壞資料中的虛假的特徵 - 標籤關聯,生成反事實樣本使得本地資料分佈靠近全域性資料分佈,模型準確率提升。
圖片
圖 4: 具有辛普森悖論的資料集
圖片
圖 5: 在辛普森悖論資料集上的全域性模型 top-1 準確率

消融實驗
圖片
圖片
圖 6:因子去相關 (FDC) 損失的消融實驗

相關文章