LLM並行訓練6-啟用最佳化

SunStriKE發表於2024-07-20

前置知識

Activation

啟用指的是一些在fp時計算得到的臨時tensor, 會用於bp時的計算. 如果能在fp計算後把臨時tensor快取下來就可以加速bp, 缺點在於啟用會佔用大量視訊記憶體. 以一層transformer結構為例分析下各層存在的啟用.

image-20240720164348045

簡單部分的分析這裡忽略. 主要分析下幾個不好理解的計算:

  1. \(QK^T\): 需要快取Q輸出和K輸出 \(2sbh+2sbh\)
  2. \(softmax\): a個head, 每個head進行s次softmax機率分類, 一共有b*s個token, 輸出為fp16, 所以是\(2 * a * s * (s * b) = 2as^2b\)
  3. \(dropout\): 需要快取一個mask陣列標記哪些位置的token需要被反向更新. 所以每個位置只需要bool, 所以是\(a * s * (s * b) = as^2b\)

如果啟用的臨時視訊記憶體全部不釋放, 每個transformerLayer需要佔用的量級為\(sbh(34+5\frac{as}{h})\) (不包含任何並行加速的情況)

image-20240720155156445

選擇性啟用重計算

重計算部分: 以GPT-3為例: a = 96, s = 2048, and h = 12288, 5as/h = 80. 佔了總啟用視訊記憶體的70%左右的主要是attention裡的部分比如: softmax的輸出, dropout的mask, dropout的輸出. 這些部分的啟用都跟矩陣乘法沒關係, 所需要的計算量很小. 把這些在前向計算完之後直接釋放視訊記憶體, 在bp用的時候再重計算的話就能以很小的計算代價換取70%的視訊記憶體儲存.

拆分儲存部分: 而比如像線性層的中間輸出, Q,K,V這些中間tensor. 因為前面有一個和W的矩陣乘法, 重計算代價巨大更適合臨時快取起來用於bp計算複用. 這些啟用透過張量並行和序列並行把他們拆分儲存, 在bp使用的時候透過集合通訊的方式再拉取, 來減少視訊記憶體消耗.

張量並行共計24sbh: 並行後拆分為t份的包括:

  • attention部分: QKV的輸入2sbh, \(QK^T\)結果4sbh, 和V相乘結果裡的2sbh. 共計8sbh
  • MLP部分: 兩個8sbh, 因為TP的先列在行的計算方法全部平分成了t份, 共計16sbh
image-20240720163425716

序列並行共計10sbh, 拆分成t份的包括:

  1. attention部分: layerNorm的輸入輸出4sbh, dropout的mask sbh
  2. MLP部分: layerNorm的輸入輸出4sbh, dropout的mask sbh

序列並行

序列並行指的是在 Transformer 層的非張量並行切分部分,計算在序列維度(s)是獨立的. 所以在這個維度上以切分數量和張量並行數相同的方式進行切分.

為什麼切分數要和TP相等呢? 回憶下TP後如何匯聚計算結果, 在行並行後allReduce, 把每張卡各自計算的結果縱向拼起來還原成完整的輸入. 如果我們想把這部分完整的輸入進行切分儲存, 如果切分數量和TP不一致意味著在$\bar{g} $ 這個地方在allReduce之後還要再進行一次reduceScatter進行分割, 在\(g\)地方再allGather..

而如果切分數和TP相等, 在\(\bar{g}\)這裡可以把allReduce直接省掉, 相當於把allReduce拆分的兩個操作. 在通訊量保持不變的情況下分離了layerNorm的啟用

image-20240720162721282

Zero-R

啟用分割槽 & checkpointing

這裡提到在模型並行的時候activation會存在冗餘副本. 這裡應該就是指的是TP輸入的冗餘副本. 論文裡是說到會把啟用給partition多份..其實我感覺實現方法和megatron裡的序列並行就是一樣的, 標記一下等細看deepspeed程式碼的時候再確認下.

另外還提到個新方法利用記憶體來存啟用checkpoint, 想了下應該是類似下圖的步驟

  1. 在最初始的幾層fp和bp的時間間隔比較遠, 適合在做完fp後memcpyAsync到記憶體.
  2. 靠近loss的後面幾層啟用還存在視訊記憶體裡, 在bp的時候直接用完就釋放了. 快到cpu啟用的部分透過memcpyAsync回來.
  3. 如果訓練和copy啟用能使用同一個stream, 那麼這塊就不需要同步, 按流的順序實行即可
image-20240720194310898

恆定大小視訊記憶體緩衝區

像all-reduce這些集合通訊操作, 在一次通訊一批很大的資料效率很高. 但缺點是會分配大量的臨時視訊記憶體, 這樣會導致視訊記憶體出現較大波動, 在大模型場景會出現問題.

所以zero在這塊設定了一個固定的buffer_size, 超過buffer_size的時候分批次通訊. cpu啟用checkpointing的copy應該也需要相應的方式.

其實在寫flux-gpu的sparse複製的時候也用了類似的方法..分批次複製來避免單次的超大通訊和小資料的碎片通訊.

視訊記憶體碎片解決方法

視訊記憶體碎片產生的原因: 在fp的時候只有一部分啟用儲存下來用於bp, 另外一些需要在bp重算的啟用被釋放了.就會導致一部分視訊記憶體的使用週期很長, 另一部分很短, 從而產生視訊記憶體碎片. 會導致兩個問題: 1. 視訊記憶體allocator查詢滿足大小的視訊記憶體塊效率很低 2. 可能會出現大塊連續視訊記憶體分配不出來.

論文裡說會給activation和grad預分配好連續視訊記憶體塊..emm, 這個做法看著和llm.c裡的實現是一樣的, 其實在大模型裡大部分的w/grad/activation在執行的時候都是定長的, 我們完全可以在第一次執行的時候全部分配好. 在網路計算的時候避免視訊記憶體分配. 如果不使用allocator就不會有碎片問題.

Zero-Offload

主要用來解決模型規模遠大於視訊記憶體規模的問題, 看著灰常似曾相識, 和部署在本地記憶體的引數伺服器很像. 感覺區別在於2點: 1. 訓練是同步的, gpu在cpu更新optimizer_state的時候只能處於等待狀態 2.記憶體裡儲存全量的引數, 不進行多機通訊(多機通訊應該會讓本來就慢的cpu更加雪上加霜吧haha).

計算策略

  1. 保證 CPU 的計算負擔遠遠小於 GPU,從而防止 CPU 成為計算瓶頸;保證 GPU 的記憶體節省最大;(optimizer_state是最佔視訊記憶體的同時, 也是不需要反覆計算的, 一個batch裡只需要存取一次, 不像fp16的w一樣還會參與反向的梯度計算)
  2. 保證 CPU 和 GPU 之間的通訊量最小;(在通訊的時候進行量化和反量化)
image-20240720202450999

排程策略

offload採用的是zero2方案, 也就是fp16 w是分卡儲存的. 考慮使用zero2的很重要的一個原因我猜測是在於多卡可以同時copy w, 而且沒有冗餘資料通訊. 避免pciE頻寬拖後腿.

下圖是單卡的資料流, swap的部分論文裡畫錯了應該是CPU->GPU, 通訊和計算非同步的地方主要有2處:

  1. g offload, 是在gpu bp的時候每計算完一層的g就async copy到記憶體
  2. p swap, cpu更新完一批w, 就分塊進行量化和async copy到視訊記憶體.
image-20240720202948505

Fp16 w到了視訊記憶體裡後就和不同的zero2計算流程完全一樣了.

後面還有一個和推薦模型cpu非同步訓練類似的cpu操作全隱藏訓練模式, 只不過區別是把非同步訓練的n個batch對齊dense改成了固定1個batch.

image-20240720203603036

Zero-Offload++

在第一版offload的時候, 所有的引數都是在cpu計算的, 上面也說到了. 在cpu計算的時候gpu只能空等, 如何在空等的時間視窗把gpu利用起來是一個很大的問題. offload++給了一個很棒的思路. 設定了一個os_w的儲存比例, 以圖示為例, 有40%的os_w存在記憶體裡由cpu更新, 剩下的60%由gpu更新. 步驟如下:

  1. 在bp完靠上層40%的網路後, 把g往記憶體copy
  2. CPU開始逐步計算已經拉下來的g, 更新os_w. 把屬於自己更新的那部分算完
  3. 到達屬於GPU更新的部分後, GPU Scatter 剩下的60% grad到os_w儲存的對應卡上, 更新視訊記憶體裡的os_w
  4. 等cpu算完後量化的fp16_w copy回視訊記憶體和視訊記憶體裡的fp16_w合併, 進行下一輪計算
image-20240720205312147

這裡的比值是人工設定的, 設定原理就是在儘量把視訊記憶體用滿的前提下儘可能的往GPU塞os_w, 塞不下的再放記憶體裡. 這個思路感覺超棒, 待細看程式碼

參考

Megatron-LM論文: https://arxiv.org/pdf/2205.05198

zero-R論文: https://arxiv.org/abs/1910.02054

zero-offload: https://www.usenix.org/system/files/atc21-ren-jie.pdf

zero-offload++部落格: https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp

megatron論文解讀: https://diveblue.notion.site/Megatron-3-Reducing-Activation-Recomputation-in-Large-Transformer-Models-b4d8bfacd33c449383aa9c61123ab578#7c3cc0cb24c444b898c4e435d12bbd4f

相關文章