切分方式
前置知識
矩陣乘法求導
以下定義X的dim為(M,K), W的dim為(K, N), 平均切分z次
行式切分
forward
先把X按列切分每個子塊的dim都是 (M, K/z), W1的dim(K/z, N), 這裡利用了分塊矩陣乘法的性質, 把切分好的Xi scatter到對應W的卡上, 計算完成後相加結果矩陣即可拿到Y的前向結果
backward:
Y對Yi的偏導因為 Y= Y1 + Y2求導偏導是1, 可以直接省略. 只需要把L對Y的偏導廣播到W1, W2各自的卡上, 他們就能各自計算對應的梯度來更新W. L對X的偏導也是兩張卡各自計算後(L對Y的偏導 * Wi的轉置), 最後按列concat到一起就能得到最終X的偏導
列式切分
forward:
因為按列切分沒有改變矩陣乘法的中間dim, 前向只需要concat起來兩個切分後的乘法結果
backward:
這裡是需要先把L對Y的導數切分後再傳給各張卡, L對W的偏導計算方法和行切分一樣, L對X的偏導因為對於損失L,X既參與了XW1的計算,也參與了XW2的計算, 所以需要把兩張卡上對X1,X2的偏導求和. 得到最終的結果
MLP並行
以Y = GELU(X * A) * B 為例
forward: 把引數A進行列切分, B進行行切分. 先把X廣播到每張卡上, 每張卡直接算完從A->B的所有流程後, AllReduce計算結果就能得到Y
Backward: 把Grad(y)廣播到各張卡上獨立反向, 然後allreduce所有的grad(xi), 就能得到grad(x)
這個設計真挺巧妙的. 如果我們只用行切分或者列切分, 在兩個矩陣計算的中間必然會進行一次集合通訊的同步. 列切分是AllGather, 行切分是AllReduce. 然而先行後列, 中間除了節省掉集合通訊的成本, 連第二次列切分的時候需要先對X做分塊操作的步驟都節省了. 牛啊
MultiHeadAttention並行
如果有兩個頭兩張卡, 把V,Q,K權重矩陣進行列切分後. 算出來的Q1,Q2 透過concat就能得到Q, 完美的切分了資料和算力..真的感覺天然適配張量並行, 只要我們保證head數能整除卡數就能完全利用起來所有的卡.
總結
張量並行結合了分塊矩陣運算的性質, 透過合理的切分輸入和引數, 再加上行列切分的合理配置. 就能節省掉很多過程中的不必要通訊和冗餘計算. 而且對效果無損, 看的過程中感覺好神奇.
參考
https://zhuanlan.zhihu.com/p/622212228