聯邦學習:多工思想與聚類聯邦學習

orion發表於2022-03-15

1.導引

電腦科學一大定律:許多看似過時的東西可能過一段時間又會以新的形式再次迴歸。

在聯邦學習領域,許多傳統機器學習已經討論過的問題(甚至一些90年代和00年代的論文)都可以被再次被發明一次。比如我們會發現聚類聯邦學習和多工學習之間就有千絲萬縷的聯絡。

2. 多工學習回顧

我們在部落格《基於正則化的多工學習》中介紹了標準多工學習的核心是多工個性化權重+知識共享[1]。如多工學習最開始提出的模型即為一個共享表示的神經網路:

多層前饋網路構建共享特徵表示

而多工學習中有一種方法叫聚類多工學習。聚類多工學習基本思想為:將任務分為若個個簇,每個簇內部的任務在模型引數上更相似。 最早的聚類多工學習的論文[2]為一種一次性聚類(one-shot clustering),即將任務聚類和引數學習分為了兩個階段:第一階段,根據在各任務單獨學習得到的模型來聚類,確定不同的任務簇。第二階段,聚合同一個任務簇中的所有訓練資料,以學習這些任務的模型。這種方法把任務聚類和模型引數學習分為了兩個階段,可能得不到最優解,因此後續工作都採用了迭代式聚類(iterative clustering),即在訓練迭代中同時學習任務聚類和模型引數的方法,

我們著重介紹Bakker等人(2003)[3]提出了一個多工貝葉斯神經網路(multi-task Bayesian neural network),其結構與我們前面所展現的共享表示的神經網路相同,其亮點在於基於連線隱藏層和輸出層的權重採用高斯混合模型(Gaussian mixture model)對任務進行分組。若給定資料集\(\mathcal{D} = \{D_t\}, t=1,...,T\),設隱藏層維度為\(h\),輸出層維度為\(T\)\(\mathbf{W}\in \mathbb{R}^{T\times (h + 1)}\)代表隱藏層到輸出層的權重矩陣(結合了偏置)。我們假定每個任務對應的權重向量\(\mathbf{w}_t\)(\(\mathbf{W}\)的第\(t\)列)關於給定超引數獨立同分布,我們假定第\(t\)個任務先驗分佈如下:

\[\bm{w}_t \sim \mathcal{N}(\bm{w}_t | \bm{u}, \mathbf{\Sigma}) \tag{1} \]

這是一個高斯分佈,均值為\(\bm{u} \in \mathbb{R}^{h + 1}\),協方差矩陣\(\mathbf{\Sigma} \in \mathbb{R}^{ (h+1) \times (h+1)}\)
我們上面的定義其實假定了所有任務屬於一個簇,接下來我們假定我們有不同的簇(每個簇由相似的任務組成)。我們設有\(C\)個簇(cluster),則任務\(t\)的權重\(w_t\)\(C\)個高斯分佈的混合分佈:

\[\bm{w}_t \sim \sum_{c=1}^C \alpha_c \mathcal{N} (\bm{w}_t | \bm{u}_c, \mathbf{\Sigma}_c) \tag{2} \]

其中,每個高斯分佈可以被認為是描述一個任務簇。式\((9)\)中的\(\alpha_c\)代表了任務\(t\)被分為簇\(c\)的先驗概率,其中task clustering(如下面左圖所示)模型中所有任務對簇\(c\)的加權\(\alpha_c\)都相同;而task-depenent模型(如下面右圖所示)中各任務對簇\(c\)的加權\(\alpha_{tc}\)不同,且依賴於各任務特定的向量\(\bm{f}_t\)
混合高斯模型對任務進行分組

3. 潛在聚類結構的的前置假設

上面我們提到的這種來自於上個世紀90年代的多工學習思想並未過時,近年來在個性化聯邦學習中又重新煥發了生機。這篇文章我們就來歸納一下與多工學習思想關係密切的聯邦學習個性化方法之一——聚類聯邦學習

這類演算法採取了一個重要的假設,就是各個client之間的資料存在潛在的簇關係。在實驗中進行資料拆分時(我們只考慮圖片資料集),我們會先將資料按照常規Non-IID演算法劃分到client,然後對client中的圖片進行旋轉處理[8]。具體的,對於8個client,如果我們設定兩個簇,那麼對其中4個client中的全部圖片(包括訓練集/驗證集/測試集)都進行旋轉180°;如果設定四個簇,則我們對其中的client-1、client2中圖片不動,client3、client4中圖片旋轉90°,client5、client6資料旋轉180°,client7、client8資料旋轉270°。

這種劃分方式,將直接導致非聚類演算法(如FedAvg)精度大大降低,因為不同旋轉模式之間的client協作起來反而是有害的。此時需要使用我們下面即將介紹的聚類聯邦學習演算法。聚類聯邦學習演算法將同一個旋轉模式的client聚為一個類簇,簇內節點可以相互進行知識共享,不同簇之間的client之間不能進行知識共享。從而達到了即讓節點之間進行知識共享,又減少了不同旋轉模式之間的client相互“負影響”的目的。

(其實,很多論文這樣創造聚類結構,並以此宣稱自己的聚類演算法優於非聚類演算法如FedAvg,就好比醫生故意把病人的腿打斷了然後給他治病一樣,也算是屬於刻意去創造條件了,但或許科研就是這樣吧)

注意,雖然這裡假定每個聚類簇對應一個旋轉模式,但其實不同旋轉模式之間的資料仍然可能會有關聯。而這就為我們進行簇間的多工知識共享提供了動機

按照其聚類的時間,我們按照上文聚類多工學習的分發,可將聚類聯邦學習的方法也分為一次性聚類迭代式聚類

一次性聚類即在模型訓練開始之前就先對client資料進行聚類。因為client資料不出庫,常常將\(T\)個client初步訓練後的引數或得到的稀疏資料表徵發往server,然後由server將\(T\)個client聚為\(S_1,...,S_k\)。然後對每一個類\(k(k\in [K])\),執行傳統聯邦學習方法(如FedAvg[5]或二階方法[6][7])求解出最終的\(\hat{\theta}_k\)。注意, 此種方式預設每個聚類簇得到一個引數\(\hat{\theta}_k\)。同樣,類似我們上面對一次性聚類多工學習的分析,如果一旦開局的聚類演算法產生了錯誤的估計,那麼接下來的時間裡演算法將無法對其進行更正。而且,如果計算的是神經網路模型引數之間的距離,還會產生另外的問題,參見下文4.1部分。

動態式的聚類為則為在訓練的過程中一邊訓練,一邊根據模型引數\(\widetilde{\theta}_t\)的情況來動態調整聚類結果。這種方式既可以簇內任務直接廣播引數\(\theta_k\),也可以僅僅共享引數的變化量\(\Delta \theta_k\)

按照其任務劃分粒度的方式。我們可將方法分為節點間多工簇間多工。節點間多工類似聚類多工學習,它為將每個節點視為一個任務(訓練一個個性化模型),然後將節點劃分為多個簇,簇內任務間進行知識共享。簇間多工為將節點劃分為多個簇,假定簇內資料IID。然後簇內節點直接廣播引數,此時將每個簇視為一個任務,然後再在簇間進行知識共享。

4. 聚類聯邦學習經典論文閱讀

4.1 arXiv 2019:《Robust federated learning in a heterogeneous environment》[11]

深度多工學習例項1

本文采用的是一次性聚類。本論文需要事先指定聚類簇的個數\(K\),然後在訓練迭代開始前,常常將\(T\)個client初步訓練後的引數\(\{\widetilde{\theta}_t\}^T_{t=1}\)發往server,然後由server執行K-means演算法將\(T\)個client聚為\(S_1,...,S_k\)。然後對每一個類\(k(k\in [K])\),單獨執行傳統聯邦學習方法(如FedAvg[5]或二階方法[6][7])求解出最終的\(\hat{\theta}_k\)。注意, 此種方式預設每個聚類簇得到一個引數\(\hat{\theta}_k\)值得一提的是,雖然該演算法屬於(個性化)聚類聯邦學習,但我個人認為嚴格來說它與標準的多工學習還有一定距離,因為它的每個簇之間缺少知識共享。 不過它確實體現了一種“多工”的思想。

然而,這種方法常常使用類似於K-means的距離聚類方法,而從理論上講對於神經網路兩個模型擁有相近的引數距離,但其輸出可能大不相同(由於模型對於隱藏層單元的置換不變性)。而近期我做過實驗,單純比較神經網路之間的歐式/餘弦距離來確定模型的相似性效果非常差,不建議大家採用

4.2 TNNLS 2020:《Clustered Federated Learning: Model-Agnostic Distributed Multitask Optimization Under Privacy Constraints》[8]

深度多工學習例項1

本文采用的是迭代式聚類節點間多工,它在每輪迭代時,都會根據節點的引數相似度進行一次任務簇劃分,同一個任務簇共享引數的變化量 \(g\),以此既能達到每個節點訓練一個性化模型,又能完成資料相似的節點相互促進(知識共享),減少資料不相似節點之間的影響的目的。本論文的亮點是聚類簇的個數可以隨著迭代變化,無需實現指定任務簇個數\(K\)做為先驗。

本篇論文演算法的每輪通訊描述如下:

(1) 第\(t\)個client節點執行

  • 從server接收對應的簇引數\(g_k\)
  • \(\theta_{old}=\theta_t=\theta_t + g_k\)
  • 執行\(E\)個區域性epoch的SGD:

    \[\theta_t = \theta_t - \eta \nabla \mathcal{l}(\theta_t; b) \]

    (此處將區域性資料\(D_t\)劃分為多個\(b\))
  • \(\widetilde{g}_t = \theta_t-\theta_{old}\)發往server。

(2) server節點執行

  • \(T\)個client接收\(\widetilde{g}_1、\widetilde{g}_2,...\widetilde{g}_T\)

  • 對每一個簇\(k\),計算簇內平均引數變化:

\[\hat{g}_{k} = \frac{1}{|S_k|}\sum_{t\in S_k}\widetilde{g}_t \]

  • 將每個簇的簇內變化\(\hat{g}_k\)發往對應的client。

  • 根據不同節點引數變化量的餘弦距離\(\alpha_{i,j}=\frac{\langle \widetilde{g}_i, \widetilde{g}_j\rangle}{||\widetilde{g}_i||||\widetilde{g}_j||}\)重新劃分聚類簇。

4.3 NIPS 2020:《An Efficient Framework for Clustered Federated Learning》[8]

深度多工學習例項1

本文采用的是迭代式聚類簇間多工。本論文需要事先指定聚類簇的個數\(K\)。它在每輪迭代時,都會根據節點的引數相似度重新進行一次任務簇劃分,第\(k\)個任務簇對其所有節點廣播引數\(\theta_k\),也就是說每個聚類簇構成一個學習任務(負責學習屬於該任務簇的個性化模型)。

本篇論文演算法的每輪通訊描述如下:

(1) 第\(t\)個client節點執行

  • 從server接收簇引數\(\theta_1,\theta_2,...\theta_{K}\)

  • 估計其所屬的簇:\(\hat{k}=\underset{k\in[K]}{\text{argmin }}\mathcal{l}(\theta_k;D_t)\)

  • 對簇引數執行\(E\)個區域性epoch的SGD:

    \[\theta_k = \theta_k - \eta \nabla \mathcal{l}(\theta_k; b) \]

    (此處將區域性資料\(D_t\)劃分為多個\(b\))

  • 將最終得到的簇引數做為\(\widetilde{\theta}_t\),並和該client對應的簇劃分一起發往server。

(2) server節點執行

  • \(T\)個client接收\(\widetilde{\theta}_1、\widetilde{\theta}_2,..., \widetilde{\theta}_T\)和各client的簇劃分情況

  • 根據節點引數的平均值來更新簇引數:

\[\hat{\theta}_k = \frac{1}{|S_k|} \sum_{t \in S_k}\widetilde{\theta}_t \]

由於現實的聚類關係很模糊,該演算法在具體實現時效仿多工學習中的權值共享[1]機制,允許不同簇(任務)之間共享部分引數。具體地,即在訓練神經網路模型時,先使用所有client的訓練資料學習一個共享表示,然後再執行聚類演算法為每個簇學習神經網路的最後一層(即多工層)。

4.4 WWW 2021:《PFA: Privacy-preserving Federated Adaptation for Effective Model Personalization》[9]

深度多工學習例項1

本文采用的是一次性聚類。本論文需要事先指定聚類簇的個數\(K\),然後在訓練迭代開始前,每個節點先訓練一個稀疏表徵模型模型,然後將稀疏表徵模型得到的稀疏向量傳到server,server再根據各client稀疏向量之間的相似度進行任務簇劃分。之後,各任務簇分別執行傳統FedAvg演算法學習每個簇對應的個性化模型。值得一提的是,和4.1中的演算法一樣,雖然該演算法屬於(個性化)聚類聯邦學習,但我個人認為嚴格來說它與標準的多工學習還有一定距離,因為它的每個簇之間缺少知識共享。 不過它也確實體現了“多工”的思想,而且與我們前面講的聚類演算法關聯十分密切,最終我還是將其加入本專題。

4.5 INFOCOM 2021:《Resource-Efficient Federated Learning with Hierarchical Aggregation in Edge Computing》

多層前饋網路構建共享特徵表示

這篇文章嚴格來說不屬於上面所說的聚類聯邦學習,也不是多工學習/個性化聯邦學習的範疇。其文中提到的clustered應該翻譯成"(區域網構成的)叢集的"。但我覺得這篇文章算是從另一個工程的角度來應用了分組聚合的思想,我在這裡還是覺得將其介紹一下。

在現實工業環境下,聯邦學習常常是基於引數伺服器(parameter server)的,但引數伺服器位於遠端雲平臺上,邊緣節點與它之間的通訊經常不可用,而邊緣節點的數量龐大,這就導致通訊代價可能會很高。本文的基本思想即是減少邊緣節點和PS之間的通訊,加強邊緣節點之間的協作。論文將不同的client劃分為了多個區域網/叢集,每個叢集都有一個leader node(LN)做為叢集頭。每個叢集先分別執行FedAvg演算法將引數聚合到LN(叢集內是同步的),然後再由PS非同步地蒐集各LN的引數並進行聚合。最後再將新的引數廣播到各個邊遠client。

本篇論文還考慮了每個client的資源消耗等複雜資訊,此處為了簡單起見,我們將簡化後演算法的每輪通訊描述如下:

(1) 第\(t\)個client節點執行

  • \(LN_k\)接收引數\(\theta\)做為本地的\(\theta_t\)
  • 執行\(E\)個區域性epoch的SGD:

    \[\theta_t = \theta_t - \eta \nabla \mathcal{l}(\theta_t; b) \]

    (此處將區域性資料\(D_t\)劃分為多個\(b\))
  • 將新的引數\(\widetilde{\theta}_t\)發往所在簇的\(LN_k\)

(2) \(LN_k\)節點執行

  • 從叢集內的\(|S_k|\)個client接收引數\(\{\widetilde{\theta}_j\}(j\in S_k)\)

  • 根據簇內節點引數的加權平均值來更新簇引數:

\[\theta_k = \frac{\sum_{j \in S_k}{ |D_j|}\widetilde{\theta}_j}{\sum_{j\in S_k} |D_j|} \]

  • 將簇引數\(\theta_k\)發往PS。
  • 從PS接收引數\(\theta\)並將其發往client節點。

(2) PS節點執行

  • \(k\)\(LN_k\)節點接收\({\theta}_k\)(非同步地)。

  • 根據一階指數平滑來更新引數:

\[\hat{\theta} = (1-\alpha)\theta + \alpha \theta_k \]

\(\quad\quad\)(在實際論文中\(\alpha\)是一個在迭代中變化的量,此處為了簡化省略)

  • 將新的引數\(\hat{\theta}\)發往\(LN_k\)

引用

  • [1] Caruana R. Multitask learning[J]. Machine learning, 1997, 28(1): 41-75

  • [2] Thrun S, O'Sullivan J. Discovering structure in multiple learning tasks: The TC algorithm[C]//ICML. 1996, 96: 489-497.

  • [3] Bakker B J, Heskes T M. Task clustering and gating for bayesian multitask learning[J]. 2003.

  • [4] Sattler F, Müller K R, Samek W. Clustered federated learning: Model-agnostic distributed multitask optimization under privacy constraints[J]. IEEE transactions on neural networks and learning systems, 2020, 32(8): 3710-3722.

  • [5] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.

  • [6] Wang S, Roosta F, Xu P, et al. Giant: Globally improved approximate newton method for distributed optimization[J]. Advances in Neural Information Processing Systems, 2018, 31.

  • [7] Ghosh A, Maity R K, Mazumdar A. Distributed newton can communicate less and resist byzantine workers[J]. Advances in Neural Information Processing Systems, 2020, 33: 18028-18038.

  • [8] Ghosh A, Chung J, Yin D, et al. An efficient framework for clustered federated learning[J]. Advances in Neural Information Processing Systems, 2020, 33: 19586-19597.

  • [9] Liu B, Guo Y, Chen X. PFA: Privacy-preserving Federated Adaptation for Effective Model Personalization[C]//Proceedings of the Web Conference 2021. 2021: 923-934.

  • [10] Wang Z, Xu H, Liu J, et al. Resource-Efficient Federated Learning with Hierarchical Aggregation in Edge Computing[C]//IEEE INFOCOM 2021-IEEE Conference on Computer Communications. IEEE, 2021: 1-10.

  • [11] Ghosh A, Hong J, Yin D, et al. Robust federated learning in a heterogeneous environment[J]. arXiv preprint arXiv:1906.06629, 2019.

相關文章