dense並行訓練2-張量並行

SunStriKE發表於2024-06-27

切分方式

前置知識

矩陣乘法求導

\[Y=f(AB)=f(C) \]

\[\frac{\partial Y}{\partial A} = \frac{\partial Y}{\partial C} \cdot B^{T} \]

\[\frac{\partial Y}{\partial B} = A^{T} \cdot \frac{\partial Y}{\partial C} \]

以下定義X的dim為(M,K), W的dim為(K, N), 平均切分z次

行式切分

image-20240627150212684

forward

\[Y= X_1W_1 + X_2W_2 \]

\[X= concat(X_1, X_2, axis=1) \]

\[W = concat(W_1, W_2, axis= 0) \]

先把X按列切分每個子塊的dim都是 (M, K/z), W1的dim(K/z, N), 這裡利用了分塊矩陣乘法的性質, 把切分好的Xi scatter到對應W的卡上, 計算完成後相加結果矩陣即可拿到Y的前向結果

backward:

\[\frac{\partial L}{\partial W_i} = \frac{\partial L}{\partial Y}\cdot \frac{\partial Y}{\partial Y_i}\cdot \frac{\partial Y_i}{\partial W_i} \\ \]

Y對Yi的偏導因為 Y= Y1 + Y2求導偏導是1, 可以直接省略. 只需要把L對Y的偏導廣播到W1, W2各自的卡上, 他們就能各自計算對應的梯度來更新W. L對X的偏導也是兩張卡各自計算後(L對Y的偏導 * Wi的轉置), 最後按列concat到一起就能得到最終X的偏導

列式切分

image-20240627151744984

forward:

\[Y= concat(X_1W_1, X_2W_2, axis=1) \\ \]

因為按列切分沒有改變矩陣乘法的中間dim, 前向只需要concat起來兩個切分後的乘法結果

backward:

\[\frac{\partial L}{\partial W_i} = \frac{\partial L}{\partial Y}\cdot \frac{\partial Y_i}{\partial W_i} \]

\[\frac{\partial L}{\partial X} = \frac{\partial L}{\partial X_1} + \frac{\partial Y_i}{\partial X_2} \\ \]

這裡是需要先把L對Y的導數切分後再傳給各張卡, L對W的偏導計算方法和行切分一樣, L對X的偏導因為對於損失L,X既參與了XW1的計算,也參與了XW2的計算, 所以需要把兩張卡上對X1,X2的偏導求和. 得到最終的結果

MLP並行

以Y = GELU(X * A) * B 為例

image-20240627165655511

forward: 把引數A進行列切分, B進行行切分. 先把X廣播到每張卡上, 每張卡直接算完從A->B的所有流程後, AllReduce計算結果就能得到Y

Backward: 把Grad(y)廣播到各張卡上獨立反向, 然後allreduce所有的grad(xi), 就能得到grad(x)

這個設計真挺巧妙的. 如果我們只用行切分或者列切分, 在兩個矩陣計算的中間必然會進行一次集合通訊的同步. 列切分是AllGather, 行切分是AllReduce. 然而先行後列, 中間除了節省掉集合通訊的成本, 連第二次列切分的時候需要先對X做分塊操作的步驟都節省了. 牛啊

MultiHeadAttention並行

image-20240627170918832

如果有兩個頭兩張卡, 把V,Q,K權重矩陣進行列切分後. 算出來的Q1,Q2 透過concat就能得到Q, 完美的切分了資料和算力..真的感覺天然適配張量並行, 只要我們保證head數能整除卡數就能完全利用起來所有的卡.

總結

張量並行結合了分塊矩陣運算的性質, 透過合理的切分輸入和引數, 再加上行列切分的合理配置. 就能節省掉很多過程中的不必要通訊和冗餘計算. 而且對效果無損, 看的過程中感覺好神奇.

相關文章