LLM並行訓練4-megascale論文學習

SunStriKE發表於2024-06-29

演算法最佳化

並行注意力機制

\[序列版本: y = x + MLP(LayerNorm(x + Attention(LayerNorm(x)))) \]

\[並行版本: y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x)))) \]

乍一看確實不是等價的, attention那塊的後置mlp去哪了..這個其實沒有理論證明, Palm論文裡提到把mlp融合到attention裡實驗62B模型上效能沒有下降. 主要對應的是下圖網路結構的並行化改造.

image-20240629132922611

滑動視窗Attention

image-20240629163320388

透過堆疊不同大小的視窗來捕獲句子中的資訊,所需要的計算量會比直接計算整個輸入文字長度的計算量要小很多

滑動視窗attention的原理參考這個文章的解釋:因為模型都是多層疊加的,所以層級越高,attend的視野域就越廣。如果w=3,那麼第一層只能注意3個位置,但到第二層能注意到第一層輸出的三個位置,換算到第一層的輸入,就是5個位置。所以隨著層級越高,理論上每個位置注意到的區域就越大,所能儲存的資訊就越接近全域性attention時的狀態

AdamW最佳化(LAMB):

adamW對比adam是把權重衰減項從梯度的計算中拿出來直接加在了最後的權重更新步驟上, 為了把權重衰減和梯度計算解耦(如果加到梯度計算裡會影響到動量的滑動平均), 從而提升最佳化效果.

image-20240629164115223

這裡做的最佳化是新增了一個 \(\phi\)截斷函式, 主要目的是為了防止batch_size太大的時候導致最佳化過程中動量出現極端值影響bp. 這個方法論文裡說可以把batch_size增大4倍從而加速訓練.

\[W_t\leftarrow W_{t-1}-\alpha \cdot\phi(\frac{||W_{t-1}||}{||r_t+\lambda W_{t-1}||})(r_t+\lambda W_{t-1}) \]

3D並行最佳化

張量並行最佳化

image-20240629151955858

序列並行(SP)主要有2個目的: 平攤LayerNorm和Dropout的計算開銷, 而且Activation佔用視訊記憶體也很多, 能夠平攤視訊記憶體消耗.

[!NOTE]

這裡有個疑問: LayerNorm不是要算全域性均值和方差麼..這個拆分後是隻算該裝置內部的均值還是說需要進行額外的allReduce?

AllGather最佳化

序列並行(SP)後, 在進行張量並行(TP)前需要在fp的時候需要先透過gather把之前層的切片從其他節點copy匯聚過來. 如果等gather完成再跑mlp和attention就會讓gpu在通訊這段時間空置等待, 這裡可以最佳化成每通訊完成一個切片後, 進行這個切片的MLP列切分計算, 同時直接把gather結果送給attention平行計算, 最後再把切片計算結果concat到一起. 比如在copy完A0後, A0的前向計算就和A1的通訊並行起來了, 這樣就能儘量的隱藏通訊

另外對矩陣做切片後再進行矩陣乘法, 計算效率要也比2個超大的矩陣乘法要高.

Reduce-Scatter最佳化

這塊是需要把匯聚計算完成的tensor在重新進行切分傳送到序列並行的節點裡, 這裡是把MLP的第二次行切分和attention結果加和給merge到了一起, 完成一個切片的計算後就傳送出去, 同步進行下一個切片的計算使計算和通訊非同步進行.

流水線最佳化

image-20240629171330492

回顧一下交錯式1F1B, 每個節點fp前需要等recv之前layer的結果, 在當前層fp完後, 透過allGather send出去計算完成的資料, 在bp的時候需要透過Reduce-scatter傳送出去計算完的grad.

在warm-up/cool-down過程裡, 都是必須等通訊完成才能進行計算的. 為了縮短等待時間megascale把allGather的recv/send拆分開, recv優先順序高於send, recv後就能直接開始計算, 不需要等send的長尾. 從而縮短等待時間.

在穩定狀態的時候應該和megatron一樣, 通訊都會和計算非同步. 實際情況裡通訊一般都會被隱藏掉(這裡我沒看懂為啥上面畫的對比圖是個純序列的流程)

資料載入最佳化

這章的主要思想工作中經常用到就不細看了, 主要有2部分:

  1. 在bp完同步梯度的時候, 所有前向相關的資料就沒用了, 就可以直接釋放回池預載入下一輪fp需要的embed
  2. 避免單機內多張卡重複讀相同的冗餘資料(這裡可能指的是embed集合麼?), 先在記憶體裡去好重再copy到視訊記憶體

網路通訊最佳化

TODO待補充..網路這塊基本都忘完了.

叢集容錯

image-20240629155710287

錯誤檢測

主要思想和flux-cpu有很多相似點, 主要有以下幾個點

  1. 每個worker定期上報心跳給中心節點, 確保當前狀態正常
  2. 狀態異常時的自動化診斷(NCCL allToAll, allReduce. 同主機RDMA網路卡間的連線和頻寬, 網路卡到GPU/MEM的連線和頻寬), 完成診斷後上報給中心節點.
  3. 中心節點向k8s申請失敗節點的拉黑和重分配替換

狀態恢復

  • checkpoint儲存: 這個看著實現方法和async_patch是一樣的, 先把引數copy到記憶體, 模型繼續訓練. 同步再起一個非同步執行緒用來把記憶體裡的引數寫到hdfs. 這樣就可以把非常耗時的hdfs寫入給隱藏掉.
  • checkpoint讀取: 主要最佳化手段是在同一資料並行組裡的卡, 只選一個GPU對應的訓練執行緒讀hdfs後寫記憶體, 然後透過broadcast給這個資料並行組裡的其他卡. 可以降低hdfs的讀取壓力.

LLM的狀態恢復感覺還挺複雜的, 如果有一個節點掛了在重分配後是所有節點全部回滾到上一個checkpoint還是有更快的方法..pipeline並行應該是在根據節點rank在啟動的時候就分好了層, 節點重入後要替換原來的rank_id.

狀態監控

基於cuda_event的timeline視覺化, 算是老熟人了. 這裡的難點感覺在於超多卡的實時日誌收集, 根據DP來畫出卡和卡的資料流依賴關係

參考:

megascale: https://arxiv.org/abs/2402.15627

Palm(並行attention): https://public.agent-matrix.com/publish/shared/Paper/Palm.pdf

滑動視窗注意力解釋: https://zhuanlan.zhihu.com/p/223430086

相關文章