聯邦學習中的優化演算法

orion發表於2022-03-04

導引

聯邦學習做為一種特殊的分散式機器學習,仍然面臨著分散式機器學習中存在的問題,那就是設計分散式的優化演算法。

以分散式機器學習中常採用的client-server架構(同步)為例,我們常常會將各client節點計算好的區域性梯度收集到server節點進行求和,然後再根據這個總梯度進行權重更新。

同步迭代框架

不過相比傳統的分散式機器學習,它需要關注系統異質性(system heterogeneity)、統計異質性(statistical heterogeneity)和資料隱私性(data privacy
)
。系統異質性體現為昂貴的通訊代價和節點隨時可能宕掉的風險(容錯);統計異質性資料的不獨立同分布(Non-IID)和不平衡。由於以上限制,傳統分散式機器學習的優化演算法便不再適用,需要設計專用的聯邦學習優化演算法。

舉個例子,傳統分散式機器學習中也提出了許多降低通訊量的演算法,包括近似牛頓法[1][2][3]、小樣本平均[5]等,但這些演算法只考慮了資料IID的情況,不能照搬過來。演算法[4]沒有假設資料IID,但是不適用深度學習,因為神經網路很難求對偶問題。

目前已經針對聯邦學習提出了許多新的優化演算法。同時,同時除了中心(centralized)化優化演算法,針對聯邦學習的去中心化(decentralized)優化演算法也得到了廣泛研究。

FedAvg——旨在減少通訊的開山之作

在聯邦學習中,首先的不同便是通訊代價。我們希望每輪通訊能夠在client上完成更多的運算(client端一般是使用者手機等裝置,充電的時候都可以計算),這也是聯邦學習開山論文[6]提出的FedAvg演算法的初衷。該演算法是聯邦學習領域最為基礎的梯度聚合方法。

相比傳統分散式機器學習方法在client節點只計算出梯度,FedAvg方法希望client節點能夠多做一些運算,得到比梯度更好的下降方向。由於這個下降方向比梯度更好,所以可以收斂更快。而收斂快了,那麼通訊次數自然就少了。這就是該演算法設計的基本想法。

該演算法的每輪通訊描述如下:

(1) 第\(k\)個client節點執行

  • 從server接收全域性模型引數\(w^t\)並令\(w_k=w^t\)
  • 執行\(E\)個區域性epoch的SGD:

    \[w_k = w_k - \eta \nabla \mathcal{l}(w_k; b) \]

    (此處將區域性資料\(D_k\)劃分為多個\(b\))
  • 將新的\(w_k\)發往server。

(2) server節點執行

  • \(K\)個client接收\(w_1^{t+1}、w_2^{t+1},...w_K^{t+1}\)

  • 按(加權)平均更新模型引數:

\[w^{t+1} = \sum_{k=1}^K\frac{n_k}{n}w_k^{t+1} \]

其中\(t\)為第\(t\)輪迭代。可以看到相比傳統分時機器學習中每個client計算完梯度就發給server,FedAvg計算完梯度後會直接更新區域性引數,同時重複該過程多次。而對於server,會對client傳來的引數進行加權平均。

注意,FedAvg還有一種變種寫法如下:

(1) 第\(k\)個client節點執行

  • 從server接收全域性模型引數\(w^t\)並令\(w_k = w^t\)
  • 執行\(E\)個區域性epoch的SGD:

    \[w_k = w_k - \eta \nabla \mathcal{l}(w_k; b) \]

    (此處將區域性資料\(D_k\)劃分為多個\(b\))
  • \(g_k = w_k-w^t\)發往server。

(2) server節點執行

  • \(K\)個client接收\(g_1、g_2,...g_K\)

  • 按(加權)平均更新模型引數:

\[w^{t+1} = w^t + \sum_{k=1}^K\frac{n_k}{n}g_k \]

兩種寫法本質上等效的。

綜上所述,FedAvg演算法在通訊次數相同的情況下,自然會收斂更快。如果實驗對比FedAvg和傳統分散式機器學習的SGD,我們會得到這樣的結果:

同步迭代框架

不過這麼做是有代價的,當client節點的計算量(以epoch來衡量)相同,那麼FedAvg的收斂速度是不如傳統SGD的。

同步迭代框架

這是典型的以計算換通訊策略。 而聯邦學習中計算代價小,通訊代價大,因此FedAvg演算法很有用。該演算法的作者已證明,FedAvg能夠在Non-IID條件下收斂[7]。論文[8]以Gboard輸入法背景下的單次預測任務為例,從工程上證明了FedAvg演算法的優越性。

FedProx——關注掉隊者

FedProx[9]主要從系統異質性和統計異質性兩個方向入手來改良FedAvg演算法。不過,介於後來FedAvg演算法已被證明在Non-IID資料集上本來能收斂[7],該演算法的貢獻還是在於提供了一個收斂更快、效果更好的演算法。

我們知道,系統異質性下,FedAvg演算法要求的每個節點執行的\(E\)個epoch的區域性迭代可能無法得到保證,因為節點隨時可能宕掉。FedProx作者還探究了統計異質性和系統異質性之間的相互作用,並認為系統異質性產生的掉隊者(stragglers) 以及掉隊者發往server的帶偏差的引數資訊會進一步增加統計異質性,最終影響收斂。因此,作者提出在client的優化目標函式中增加一個近端項,這樣可以使優化演算法更加穩定,最終使得FedProx在統計異質性下也收斂更快。

FedProx中server的操作和FedAvg相同,都是採用(加權)平均,但是其第\(k\)個client端不再是執行\(E\)輪的SGD,而是求解以下帶近端項的優化問題:

\[\begin{aligned} w^{t+1}_k &=\underset{w}{\text{argmin}}h(w,w^t_k)\\ & = \mathcal{l}(w, D_k) + \frac{\mu}{2}||w-w^t_k||^2 \end{aligned} \]

其中\(\mathcal{l}(w, D_k)\)為客戶端原本的優化函式,\(\frac{\mu}{2}||w-w^t||^2\)為近端項。作者在論文中證明了近端項的新增能夠使FedProx更好地適用於統計異質和系統異質的環境。

我們可以認為FedAvg是FedProx中將\(\mu\)設定為0,求解器設定為SGD,Epoch設定為\(E\)(當然這樣就無法處理系統異質性)的特殊情況。

FedAvg+ ———用元學習做模型個性化

在傳統的聯邦學習中,每個client節點聯合聯合訓練出各一個全域性的模型(在前文中即server節點的\(w_t\))。但是由於資料Non-IID,訓練出的全域性模型很難對每個區域性節點都適用,不夠“個性化”。

Jiang Y等人的這篇論文[10]首次採用了個性化聯邦學習的思路:不求訓練出一個全域性的模型,而使每個節點訓練各不相同的模型。作者在論文中採用模型不可知的元學習(Model Agnostic Meta Learning, MAML) 思路。元學習在給定小樣本例項的條件下進行自適應,可以優化在異構任務上的表現,它由兩個步驟構成:“meta training”——訓練初始化模型/元模型和“meta testing”——使初始化模型在特定的任務下完成自適應

作者認為傳統的FedAvg演算法[6]可以被解釋為一種元學習演算法。在此基礎上再進行仔細的微調(fine-tuning)能夠使全域性模型少一些泛化性,但同時能夠更容易個性化。我們將全域性的已訓練的模型稱為初始化模型(initial model),將區域性的已訓練模型稱為個性化模型(personalized model)。論文沒有采用[11]中將訓練初始化模型和模型個性化的操作分離,作者認為這樣會陷入區域性最優,作者提出的演算法包括以下3個連續的步驟:

(1) 執行傳統的FedAvg演算法得到初始化模型,其中採用更大的\(E\),並使用帶動量的SGD做為優化器。

(2) 採用FedAvg的變種演算法對初始化模型進行微調: 此時採用Adam做為優化器,且迭代不再是採用\(E\)個epoch,而是先從\(D_k\)中隨取樣\(M\)個資料集\(\{D_{k,m}\}\)(\(M\)一般較小),然後進行如下的\(M\)個迭代步:

\[w_k = w_k - \eta \nabla \mathcal{l}(w_k; D_{k, m}) \]

(3) 對client進行進行個性化操作,採用和訓練期間相同的優化器。

作者認為該演算法能夠得到更穩固的初始化模型,這樣對於一些clients只有有限的甚至沒有資料來做個性化的情況很有好處。

Clustered FL——多工知識共享

聚類聯邦學習(CFL)[12]這篇論文針對資料Non-IID導致的區域性最優,提出了一種新的聯邦學習個性化方法:聚類(多工)聯邦學習。

CFL保持著個性化聯邦學習的基本假設:每個節點訓練各不相同的模型。但並沒有採用元學習中初始化模型+自適應的措施,而是借用多工學習中的常見手段,即讓節點在訓練的過程中就進行知識共享(可以參見我的部落格《基於正則表示的多工學習》),而無需另設一個初始化模型。更具體的,CFL採用的是聚類多工學習(clustered multitask learning),在訓練的過程中將引數相似的節點劃分為同一個任務簇,同一個任務簇共享引數的變化量\(g\),以此既能達到完成知識共享和相似的節點相互促進的目的。

聚類聯邦學習演算法的每輪通訊描述如下:

(1) 第\(k\)個client節點執行

  • 從server接收\(g_{c(k)}\)
  • \(w_{old}=w_k=w_k + g_{c(k)}\)
  • 執行\(E\)個區域性epoch的SGD:

    \[w_k = w_k - \eta \nabla \mathcal{l}(w_k; b) \]

    (此處將區域性資料\(D_k\)劃分為多個\(b\))
  • \(g_k = w_k-w_{old}\)發往server。

(2) server節點執行

  • \(K\)個client接收\(g_1、g_2,...g_K\)

  • 對每一個簇\(c\in \mathcal{C}\),計算簇內平均引數變化:

\[g_{c} = \frac{1}{|c|}\sum_{k\in c}g_k \]

  • 根據不同節點引數變化量的餘弦距離\(\alpha_{i,j}=\frac{\langle g_i, g_k\rangle}{||g_i||||g_j||}\)重新劃分聚類簇。

CFL的簇劃分演算法採用的是不斷進行二分裂的方式,無需指定簇的數量做為先驗。該演算法最重要的貢獻就是簇間知識共享思想的引入(並不共享引數,而共享引數的變化量,注意和論文[15]中直接引數的平均進行區分)。

pFedMe—純優化視角的個性化

pFedMe[13]這篇論文繼續瞄準聯邦學習個性化,它的創新點是使用Moreau envelope(也稱Moreau-Yosida正則化)做為client的正則損失函式。該演算法比已有的許多演算法收斂速度更快。

這個方法的一大貢獻將個性化模型與全域性模型同時進行優化求解,該方法按照與標準FedAvg相似的方法來更新全域性模型(多了個一階指數平滑),不過會以更低的複雜度來對個性化模型進行優化。

該篇論文演算法的每輪通訊描述如下:

(1) 第\(k\)個client節點執行

  • 從server接收全域性模型引數\(w^t\)並令\(w_k = w^t\)

  • 執行\(R\)輪區域性迭代:

    \[w_k = w_k - \eta \mu(w_k - \hat{\theta}_k(w_k)) \]

    其中

    \[ \hat{\theta}_k(w_k)= \underset{\theta_k \in \mathbb{R}^d}{\text{argmin}} \{ \mathcal{l}(\theta_k, D_k) + \frac{\mu}{2}||\theta_k - w_k||^2 \} \]

  • 將新的\(w_k\)發往server。

(2) server節點執行

  • \(K\)個client接收\(w_1^{t+1}、w_2^{t+1},...w_K^{t+1}\)

  • 按與平均值的一次指數平滑更新模型引數:

    \[w^{t+1} = (1-\beta)w^t + \beta \frac{1}{K} \sum_{k=1}^K w_k^{t+1} \]

其中重點在於client每輪區域性迭代中求解Moreau envelope的部分,即求解\(\hat{\theta}_k(w_k)\)的部分。這裡\(\theta_k\)表示第\(k\)個client的個性化模型,\(\mu\)引數用於控制全域性模型引數\(w_k\)相對於個性化模型的強度。其中Moreau envelope部分可以採用任意迭代方法求解。

FedEM—混合分佈假設與EM演算法

FedEM[14]這篇論文另闢蹊徑,沒有關注模型的個性化,而是考慮從優化演算法上去提高聯邦學習模型的精度,其中採用的手段有兩點,一點是基於client節點資料滿足混合分佈的假設,使每個client節點訓練由\(M\)個子模型整合所得的模型;二點是針對混合分佈的假設,採用EM演算法來做引數估計,提高了模型的整體精度。

該演算法中心化形式的每輪通訊描述如下:

(1) 第\(k\)個client節點執行

  • 從server接收全域性模型引數\(w^t\)並令\(w_k = w^t\)

  • 對每一個模型成分\(m\)(\(m=1,..., M\))以及每一個區域性樣本\(i\)(\(i=1,...,n_t\))執行\(\text{E}\)步驟

    \[q_k(z^{(i)}_k=m)\leftarrow \frac{\pi _{km}\cdot \text{exp}\left(-\mathcal{l}(h_{w_{km}}(x_k^{(i)}), y_k^{(i)})\right)} {\sum_{m'=1}^M \pi _{km'}\cdot \text{exp}\left(-\mathcal{l}(h_{w_{km'}}(x_k^{(i)}), y_k^{(i)})\right)} \]

    對每一個模型成分\(m\)執行\(\text{M}\)步驟

    \[ \pi_{km} = \frac{\sum_{i=1}^{n_t} q_k(z^{(i)}_k=m)}{n_t} \]

    對每一個模型成分\(m\)執行\(J\)輪區域性迭代:

    \[w_{km} = w_{km} - \eta_j\sum_{i\in \mathcal{I}}q_k(z^{(i)}_k=m)\cdot \nabla_{w_{km}}\mathcal{l}(h_{w_{km}}(x_k^{(i)}), y_k^{(i)}) \]

    \(\mathcal{I}\)為每輪迭代有放回地從\(1,2,...|D_k|\)中採的隨機樣本索引集合

  • 將新的\(w_k\)發往server。

(2) server節點執行

  • \(K\)個client接收\(w_1^{t+1}、w_2^{t+1},...w_K^{t+1}\)

  • 對每一個模型成分\(m\)按(加權)平均更新模型引數:

\[w^{t+1}_m = \sum_{k=1}^K\frac{n_k}{n}w_{km}^{t+1} \]

該演算法的中心化形式在許多資料集上精度都取得了SO他的水平。

該演算法的去中心化形式的每輪通訊描述如下:

\(k\)個client節點執行:

  • 對每一個模型成分\(m\)(\(m=1,..., M\))以及每一個區域性樣本\(i\)(\(i=1,...,n_t\))執行\(\text{E}\)步驟

    \[q_k(z^{(i)}_k=m)\leftarrow \frac{\pi _{km}\cdot \text{exp}\left(-\mathcal{l}(h_{w_{km}}(x_k^{(i)}), y_k^{(i)})\right)} {\sum_{m'=1}^M \pi _{km'}\cdot \text{exp}\left(-\mathcal{l}(h_{w_{km'}}(x_k^{(i)}), y_k^{(i)})\right)} \]

    對每一個模型成分\(m\)執行\(\text{M}\)步驟

    \[ \pi_{km} = \frac{\sum_{i=1}^{n_t} q_k(z^{(i)}_k=m)}{n_t} \]

    對每一個模型成分\(m\)執行\(J\)輪區域性迭代:

    \[w_{km} = w_{km} - \eta_j\sum_{i\in \mathcal{I}}q_k(z^{(i)}_k=m)\cdot \nabla_{w_{km}}\mathcal{l}(h_{w_{km}}(x_k^{(i)}), y_k^{(i)}) \]

    \(\mathcal{I}\)為每輪迭代有放回地從\(1,2,...|D_k|\)中採的隨機樣本索引集合

  • 將新的\(w_k\)發往其鄰居節點

  • 從鄰居節點接收新的\(w_k\)

  • 對每一個模型成分\(m\)按(加權)平均更新模型引數:

\[w^{t+1}_m = \sum_{k=1}^K\lambda_{km}w_{km}^{t+1} \]

(其中加權引數\(\lambda\)為隨機初始化)

引用

  • [1] Shamir O, Srebro N, Zhang T. Communication-efficient distributed optimization using an approximate newton-type method[C]//International conference on machine learning. PMLR, 2014: 1000-1008.

  • [2] Wang S, Roosta F, Xu P, et al. Giant: Globally improved approximate newton method for distributed optimization[J]. Advances in Neural Information Processing Systems, 2018, 31.

  • [3] Mahajan D, Agrawal N, Keerthi S S, et al. An efficient distributed learning algorithm based on effective local functional approximations[J]. arXiv preprint arXiv:1310.8418, 2013.

  • [4] Smith V, Forte S, Chenxin M, et al. CoCoA: A general framework for communication-efficient distributed optimization[J]. Journal of Machine Learning Research, 2018, 18: 230.

  • [5] Zhang Y, Duchi J, Wainwright M. Divide and conquer kernel ridge regression: A distributed algorithm with minimax optimal rates[J]. The Journal of Machine Learning Research, 2015, 16(1): 3299-3340.

  • [6] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.

  • [7] Stich S U. Local SGD converges fast and communicates little[C]///International Conference on Learning Representations, 2018.

  • [8] Hard A, Rao K, Mathews R, et al. Federated learning for mobile keyboard prediction[J]. arXiv preprint arXiv:1811.03604, 2018.

  • [9] Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia
    Smith. “Federated Optimization in Heterogeneous Networks”. In: Third MLSys Conference.2020.

  • [10] Jiang Y, Konečný J, Rush K, et al. Improving federated learning personalization via model agnostic meta learning[J]. arXiv preprint arXiv:1909.12488, 2019.Presented at NeurIPS FL workshop 2019.

  • [11] Sim K C, Zadrazil P, Beaufays F. An investigation into on-device personalization of end-to-end automatic speech recognition models[J]. In Interspeech, 2019.

  • [12] Sattler F, Müller K R, Samek W. Clustered federated learning: Model-agnostic distributed multitask optimization under privacy constraints[J]. IEEE transactions on neural networks and learning systems, 2020, 32(8): 3710-3722.

  • [13]
    T Dinh C, Tran N, Nguyen J. Personalized federated learning with moreau envelopes[J]. Advances in Neural Information Processing Systems, 2020, 33: 21394-21405.

  • [14] Marfoq O, Neglia G, Bellet A, et al. Federated multi-task learning under a mixture of distributions[J]. Advances in Neural Information Processing Systems, 2021, 34.

  • [15]
    Liu B, Guo Y, Chen X. PFA: Privacy-preserving Federated Adaptation for Effective Model Personalization[C]//Proceedings of the Web Conference 2021. 2021: 923-934.

相關文章