LLM並行訓練1-流水線並行

SunStriKE發表於2024-06-26

並行訓練-流水線

簡述

並行訓練主要有三種策略:

  1. 資料並行訓練加速比最高,但要求每個裝置上都備份一份模型,視訊記憶體佔用比較高,通訊量大
  2. 張量並行,通訊量比較高,適合在機器內做模型並行。
  3. 流水線並行,訓練裝置容易出現空閒狀態,加速效率沒有DP高;但能減少通訊邊界支援更多的層數,適合在機器間使用。

流水線並行

Micro-batch(Gpipe)

將網路結構進行縱向拆分, 每張卡訓練其中的幾層. 下圖把網路拆成4層. 如果是按照純粹的mini-batch訓練, 每層之間是純序列的. 後面的卡會始終等待前面的卡. 所以引入了micro-batch的概念. 把mini-batch進行更細粒度的拆分, 這樣在完成batch0的fp之後, 卡0可以進行batch1的fp, 卡1就能開始batch0的fp. 從而提高並行度.

image-20240626203004792

存在的問題:

  1. 存在bubble_time: 每張卡的空閒時間 = (stage_num - 1) * (fp_time + bp_time)

\[\frac{(stageNum - 1)(tf + tp)}{(stageNum - 1)(tf + tp) + microNum(tf + tp)} = \frac{stageNum - 1}{microNum + stageNum - 1} \]

實際應用中 當mico-batch個數大於stageNum的4倍時, 可以忽略bubble_time

  1. 視訊記憶體浪費: 當進行stage3的micro-batch 3時, 還需要儲存前面所有mico-batch的fp中間結果用於bp.
  2. 在每個mini-batch之間無法並行. 因為下一個minibatch需要等當前所有的micro-batch更新完引數

PipeDream(非交錯式1F1B DeepSpeed)

image-20240626191344318

在每個micro-batch fp完成之後立刻優先進行bp. 這樣可以把當前batch的中間變數釋放掉, bp完成後更新本機引數, 但這種方式存在引數更新衝突, 機器1和機器2使用的引數不一樣, 機器1的batch5只用了 batch1反向後更新的引數, 但機器2的使用了batch2的, PipeDream透過多版本引數cache的思想來解決這個問題

image-20240626210948570

為啥worker1需要儲存4個版本引數, 而worker4只需要1個呢? 這裡的版本數和同一個batch fp和bp的間隔決定的. 如果我跑完fp後, 中間有其他batch更新的bp. 那就需要把這些bp結果給快取起來, 不然就會導致fp和bp使用的不是同一份引數. 可以看到worker1的batch5 中間間隔了2,3,4 3次bp, 再加上它本身. 就得儲存4份...這種方法對視訊記憶體極度不友好, 所以有了下面的flush方式

1F1B-flush

image-20240626201932709

對比上面的F-then-B的方式, 1F1B優先bp計算. 每個micro-batch完成後直接釋放掉了對應micro-batch的計算中間值.

只需要儲存1份w, 在固定micro-batch個數後進行一次flush, 同步所有worker的權重使其保持同一個版本.

另外在stage3中 batch1 fp時, 因為batch0已經算完了. 所以可以直接複用batch0的視訊記憶體不用重新分配.

[!NOTE]
這裡有個疑問..越底層的stage需要快取的中間值其實越多, 這種造成儲存不均勻的問題怎麼解決? 透過stage切分不同大小引數的方式麼

1F1B-flush(交錯式, megatron)

image-20240626213836419

這個方案有個新的概念, virtual_pipeline, 方案要求一個小批次中的微批次數量是管道並行大小(流水線中的裝置數量)的整數倍

按之前非交錯式的方法. 一共有8層, worker1如果是1/2層, worker2是3/4層..worker4是7/8層, 每個worker計算連續的層

那麼virtual_pipeline如果是2的話, 會把每個worker進一步拆分, worker1變成了計算1/5層, worker2: 2/6層..類推, 相當於透過把每個worker從單一流水線拆成了virtual_pipeline個流水線.

  • 在之前的1F1B模式裡, 因為每個機器計算是有先後順序的, worker2的通訊接收worker1的fp結果必須等worker1的fp完成.
  • 而在交錯式設計裡, worker2計算的是2/6層, 當他計算2的時候, 可以同步從worker1拿上一個batch的5層結果, 算完2後的理想狀態就是直接算5. 能更好的把通訊隱藏起來.

總結這個方案的優點:

  • 相鄰的計算與通訊操作無依賴關係, 可以加速並行執行
  • 發起通訊操作時,通訊的對端通常已經準備好了要通訊的資料,通訊操作不需要額外的等待時間。

相關文章