聯邦平均演算法(Federated Averaging Algorithm,FedAvg)

MaplesWCT發表於2022-04-18

設一共有\(K\)個客戶機,

中心伺服器初始化模型引數,執行若干輪(round),每輪選取至少1個至多\(K\)個客戶機參與訓練,接下來每個被選中的客戶機同時在自己的本地根據伺服器下發的本輪(\(t\)輪)模型\(w_t\)用自己的資料訓練自己的模型\(w^k_{t+1}\),上傳回伺服器。伺服器將收集來的各客戶機的模型根據各方樣本數量用加權平均的方式進行聚合,得到下一輪的模型\(w_{t+1}\)

\[\begin{aligned} & \qquad w_{t+1} \leftarrow \sum^K_{k=1} \frac{n_k}{n} w^k_{t+1} \qquad\qquad //n_k為客戶機k上的樣本數量,n為所有被選中客戶機的總樣本數量\\ \end{aligned} \]

【虛擬碼】

\[\begin{aligned} & 演算法1:Federated\ Averaging演算法(FedAvg)。 \\ & K個客戶端編號為k;B,E,\eta分別代表本地的minibatch\ size,epochs,學習率learning\ rate \\ & \\ & 伺服器執行:\\ & \quad 初始化w_0 \\ & \quad for \ 每輪t=1,2,...,do \\ & \qquad m \leftarrow max(C \cdot K,1) \qquad\qquad //C為比例係數 \\ & \qquad S_t \leftarrow (隨機選取m個客戶端) \\ & \qquad for \ 每個客戶端k \in S_t 同時\ do \\ & \qquad \qquad w^k_{t+1} \leftarrow 客戶端更新(k,w_t) \\ & \qquad w_{t+1} \leftarrow \sum^K_{k=1} \frac{n_k}{n} w^k_{t+1} \qquad\qquad //n_k為客戶機k上的樣本數量,n為所有被選中客戶機的總樣本數量\\ & \\ & 客戶端更新(k,w): \qquad \triangleright 在客戶端k上執行 \\ & \quad \beta \leftarrow (將P_k分成若干大小為B的batch) \qquad\qquad //P_k為客戶機k上資料點的索引集,P_k大小為n_k \\ & \quad for\ 每個本地的epoch\ i(1\sim E) \ do \\ & \qquad for\ batch\ b \in \beta \ do \\ & \qquad \qquad w \leftarrow w-\eta \triangledown l(w;b) \qquad\qquad //\triangledown 為計算梯度,l(w;b)為損失函式\\ & \quad 返回w給伺服器 \end{aligned} \]

為了增加客戶機計算量,可以在中心伺服器做聚合(加權平均)操作前在每個客戶機上多迭代更新幾次。計算量由三個引數決定:

  • \(C\),每一輪(round)參與計算的客戶機比例。
  • \(E(epochs)\),每一輪每個客戶機投入其全部本地資料訓練一遍的次數。
  • \(B(batch size)\),用於客戶機更新的batch大小。\(B=\infty\)表示batch為全部樣本,此時就是full-batch梯度下降了。

\(E=1\ B=\infty\)時,對應的就是FedSGD,即每一輪客戶機一次性將所有本地資料投入訓練,更新模型引數。

對於一個有著\(n_k\)個本地樣本的客戶機\(k\)來說,每輪的本地更新次數為\(u_k=E\cdot \frac{n_k}{B}\)

參考文獻:

  1. H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. Y. Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Proc. AISTATS, 2016, pp. 1273–1282.

相關文章