閱讀筆記(Communication-Efficient Learning of Deep Networks from Decentralized Data)

你看見的我發表於2020-10-20

閱讀筆記(Communication-Efficient Learning of Deep Networks from Decentralized Data)

可訪問文章
可訪問程式碼

動機

傳統資料訓練方式,使用者將本地資料提交給可信雲伺服器進行統一集中並訓練,如何保證使用者本地資料的隱私性
同時,計算成本和通訊開銷如何保證?

聯邦學習(聯合學習)的提出,可以較好地解決上述問題。

  1. 通俗定義:分散式(但有差異)訓練,不需要專門集中資料,避免洩露資料隱私,節省開銷,利用可信第三方聚合上傳的本地引數(或梯度),資料標籤不一致可以通過使用者互動進行樣本對齊。
    特別地,關於可信第三方是否值得信賴的問題,可借鑑半可信(HBC)模型,結合聯邦學習和安全方案(差分隱私、同態加密)。
  2. 優化方向與挑戰:(區別於分散式學習)

– 資料非獨立同分布(Non-IID)
– 資料不平衡(不同使用者參與度不一致,每個使用者擁有的資料樣本數不一致)
–大規模分散式(大規模使用者如何應對以控制開銷)
–通訊限制(使用者傳輸引數(梯度)後多處於離線狀態)

  1. 本文擬解決問題:資料Non-IID、資料不平衡、使用者(客戶端)可用性、優化通訊開銷。
  2. 問題描述與解決方法:
    – Non-IID:
    w w w表示網路引數, i i i代表樣本索引, k k k表示使用者(客戶端)索引(共 K K K個使用者), P k P_k Pk表示使用者k擁有的資料樣本集, n k n_k nk表示使用者 k k k的本地資料集 P k P_k Pk的數量, n n n表示所有資料的數量, f i ( w ) = l ( x i , y i ; w ) f_i(w)=l(x_i,y_i;w) fi(w)=l(xi,yi;w)表示每個樣本的訓練損失, F k ( w ) F_k(w) Fk(w)表示 P k P_k Pk上的平均損失, f ( w ) f(w) f(w)表示聚合所有使用者資料集( P k P_k Pk)的全域性平均損失。
    若每個使用者 k k k的資料集 P k P_k Pk是均勻分佈的,即期望 E P k [ F k ( w ) ] E_{P_k}[F_k(w)] EPk[Fk(w)]=f(w),這意味著 P k P_k Pk符合IID獨立同分布
    –通訊開銷
    若採用資料中心訓練,使用者將本地資料傳送給資料中心,顯然通訊開銷相對較小,主要為計算開銷;
    在聯邦學習場景下,使用者端的本地資料量遠遠小於總資料量,使用者計算開銷可以忽略不記,通訊開銷佔據主導地位,因而,如何降低通訊開銷是需要考慮的問題。
    使用者自願參與聯邦學習,參與通訊輪數十分有限,因此本文採用增加計算開銷的方式減少通訊開銷
    兩種方式:①並行化,採用更多的使用者參與訓練優化;
    ②增加每個使用者的計算開銷,減少通訊輪數。

聯邦平均

FedAvg演算法的計算開銷與三個引數相關;

  1. C:每輪通訊內,執行計算的使用者所佔的比例( 0 ≤ C ≤ 1 0 \le C \le 1 0C1),C=1表示所有使用者均參與聯邦優化, C ⋅ K C \cdot K CK表示參與使用者數量,C=0表示有且僅有1個使用者均參與聯邦優化。
  2. E:每輪通訊內,每個使用者對其本地資料集進行的訓練迭代次數( E ≥ 1 E \ge 1 E1);
  3. B:使用者更新訓練本地資料集的小批量大小( B ≥ 1 B \ge 1 B1), B = ∞ B=\infty B=表示採用單個批量處理整個本地資料集。
    特別地,在每輪通訊內,使用者 k k k執行本地更新的次數 u K = E ⋅ n k / B u_K=E \cdot n_k / B uK=Enk/B,若 B = ∞ B=\infty B=,表示 n k / B = 1 n_k / B=1 nk/B=1,即 u K = E u_K=E uK=E

FedSGD = FedAvg ( B = ∞ B=\infty B= E = 1 E=1 E=1)

在這裡插入圖片描述
一輪通訊並意味著挑選的使用者只執行一次迭代訓練,迭代次數與E有關;

演算法解釋:m表示隨機挑選出的參與聯邦優化的使用者數量,且 m ≥ 1 m \ge 1 m1,在第t輪通訊,隨機使用者構成子集 S t S_t St
S t S_t St中的每個使用者 k k k並行地利用全域性聚合引數 w t w_t wt執行引數更新,獲得 w t + 1 k w_{t+1}^k wt+1k
– 使用者 k k k根據批處理大小 B B B將本地資料集 P k P_k Pk拆分為 B \mathcal{B} B個批量,在每個迭代中,按部就班地根據每個批量更新權重引數。
伺服器將所有使用者 k k k的更新引數使用者 k k k執行加權聚合獲得 w t + 1 w_{t+1} wt+1,權重為每個使用者擁有的資料樣本佔該輪通訊的所有資料樣本的比例;

公式推導:使用者 k k k計算本地資料集上的平均梯度 g k = ∇ F k ( w t ) g_k=\nabla F_k(w_t) gk=Fk(wt),中心伺服器可以通過聚合這些梯度更新權重引數
w t + 1 = w t − η ∇ f ( w t ) ← w t − η ∑ k = 1 K n k n g k w_{t+1} = w_{t} - \eta \nabla f(w_t) \leftarrow w_{t} - \eta \sum_{k=1}^K \frac{n_k}{n} g_k wt+1=wtηf(wt)wtηk=1Knnkgk。(使用者上傳梯度
等價於:
使用者通過區域性平均梯度更新本地網路引數
w t + 1 k = w t − η g k w_{t+1}^k = w_{t} - \eta g_k wt+1k=wtηgk
伺服器通過全域性平均聚合這些網路引數,更新獲得全域性的網路引數
w t + 1 = ∑ k = 1 K n k n w t + 1 k w_{t+1} = \sum_{k=1}^K \frac{n_k}{n} w_{t+1}^k wt+1=k=1Knnkwt+1k。(使用者上傳網路引數

神經網路是非凸的,對於一般的非凸目標,引數空間上的平均模型可能產生任意差的模型。
初步解決方案:共享相同的隨機種子,即採取相同的引數初始化。

實驗結果

目標是為了評估本文的聯邦優化方法,而不是追求最高的分類精度

受到影像分類和語言建模任務的激勵,好的模型可以大大提高移動裝置的可用性。
對於每一個任務,我們首先選擇一個足夠小的代理資料集,以便我們可以徹底研究FedAvg演算法的超引數。
FedAvg演算法如何為使用者劃分資料集?

本文圍繞影像分類和語言模型兩個任務開展實驗,下面主要介紹並解釋影像分類實驗。
兩類真實資料集:

  1. MINIST(60000訓練集(55000訓練+5000驗證)+10000測試集)
  2. CIFAR-10(50000訓練集+10000測試集)

兩個網路:

  1. 2NN
    –Input(28×28×1) → \rightarrow FC(200) + ReLU → \rightarrow FC(200) + ReLU → \rightarrow Output(FC10)
  2. CNN
    –Input(28×28×1) → \rightarrow Conv(5×5) → \rightarrow 24×24×32 → \rightarrow Max-pool (2×2) → \rightarrow 12×12×32 → \rightarrow Conv(5×5) → \rightarrow 12×12×64 → \rightarrow Max-pool (2×2) → \rightarrow 6×6×64 → \rightarrow FC(512) + ReLU → \rightarrow Output(FC10) → \rightarrow Softmax (10)

兩種MINIST資料分佈:

  1. IID
    –打亂資料,100個使用者,每個使用者分配6000個樣本(含完整標籤0-10)
  2. Non-IID
    –首先按數字標籤對資料進行排序,將其劃分為大小為300的200個碎片,指定100個使用者,為每個使用者分配2個碎片(即100×600)。

MINIST優化實驗

探索使用者批量與通訊輪數(精度提升快慢或損失函式收斂速度)的關係
(2NN達到目標精度97%(E=1),CNN達到目標精度99%(E=5))
(基線:C=0,表示單個使用者參與優化)
在這裡插入圖片描述
結果表明:當 B = ∞ B=\infty B=時(使用者所有資料一次性地迭代優化),隨著通訊輪數增加,精度提升對參與使用者數量不敏感;
B = 10 B=10 B=10時(小批量),隨著通訊輪數增加,精度提升對參與使用者數量十分敏感(特別是Non-IID資料);
∗ ∗ C = 0.1 ∗ ∗ **C=0.1** C=0.1時(每輪通訊10個使用者參與優化),可以達到計算效率和收斂速度的平衡(min(C*通訊輪數),即最小化使用者總計算量)。

探索使用者計算量(增加使用者開銷)與通訊輪數(精度提升快慢或損失函式收斂速度)的關係
(CNN達到目標精度99%(固定C=0.1, 每輪通訊10個使用者參與聯邦優化))
如何增加每輪通訊內的使用者開銷?增加迭代次數E;減少小批量B;均可以增加迭代更新次數 u u u
n n n=60000, K K K=100, n k n_k nk=600, u K = E ⋅ n k / B u_K=E \cdot n_k / B uK=Enk/B
在這裡插入圖片描述
隨著
每輪通訊
使用者執行的更新次數(使用者承擔的計算量)越多,那麼達到目標精度所需要的通訊輪數越少。
當E=5, B=50, u=60,可以達到使用者計算效率和收斂速度的平衡。(min(u*通訊輪數),最小化總更新次數)

在這裡插入圖片描述
在這裡插入圖片描述
在這裡插入圖片描述

CIFAR-10優化實驗

基線SGD:不分使用者塊,小批量=100;
FEDSGD: B = ∞ B=\infty B= E = 1 E=1 E=1,(固定 C = 0.1 C=0.1 C=0.1);
FEDAVG: B = 50 B=50 B=50 E = 5 E=5 E=5,(固定 C = 0.1 C=0.1 C=0.1);
在這裡插入圖片描述

對於適當的C和E值,FedAvg在每次小批計算中取得了類似的進展。
標準的SGD和FedAvg在每輪只有一個客戶(C = 0)情況下,在準確性上都表現出了顯著的波動,而對更多客戶進行平均則消除了這種波動。
在這裡插入圖片描述

學習率與優化演算法的關係曲線
在這裡插入圖片描述

結論:

實驗表明,聯邦學習是可行的,FedAvg使用相對較少的通訊次數訓練高質量的模型。
雖然聯邦學習提供了許多實用的隱私好處,但通過差分隱私安全多方計算或它們的組合提供更強的資料隱私保證是未來工作的一個有趣方向。

下一篇閱讀筆記:

  1. 聯邦學習場景下的安全聚合協議 → \rightarrow Practical Secure Aggregation for Federated Learning on User-Held Data

  2. 聯邦學習場景下的提高通訊效率策略 → \rightarrow FEDERATED LEARNING: STRATEGIES FOR IMPROVING COMMUNICATION EFFICIENCY

相關文章