LLM並行訓練3-資料並行

SunStriKE發表於2024-06-28

前置知識

混合精度訓練

image-20240627193640147

在引數儲存時採取fp32, 開始進行fp/bp時轉成fp16運算, 拿到fp16梯度後再轉回fp32更新引數.

ZeRO對視訊記憶體佔用的估算:

  • 模型狀態: Weights(fp16)、grad(fp16) 和 MasterWeights(fp32 模型引數備份),momentum(fp32)和variance(fp32)。假設模型引數量 \(\phi\) ,則共需要\(2\Phi + 2\Phi + (4\Phi + 4\Phi + 4\Phi) = 4\Phi + 12\Phi = 16\Phi\) 位元組儲存,
  • 剩餘狀態: 除了模型狀態之外的視訊記憶體佔用,包括啟用值(activation)、各種臨時緩衝區(buffer)以及無法使用的視訊記憶體碎片(fragmentation)

Adam

image-20240627210940124

在adam optimizer的計算狀態除了引數, 還有一個\(m_t\)(momentum 梯度均值)和\(v_t\)(variance 梯度未中心化方差)需要儲存, 一般被稱為optimizer state.

AllToAll通訊原語

image-20240628204846386

allToall類似於矩陣轉置. 相當於我們需要先把每個節點裡的資料按照他們要傳遞給哪個節點排好序, 然後根據切分好的順序推給對應的節點. 可以看到如果每個節點的資料量是M, 節點數是N, 最終通訊總量就是M * N

ZeRO

在傳統的訓練方法裡, 每張卡里儲存一份完整的模型狀態, 完成bp後allReduce grad,再更新每張卡里的副本. 這樣子有N張卡就會多出(N-1)份冗餘的引數儲存. 當引數規模急劇增大時這種方法就完全不適合訓練. ZeRO1 主要是將這些冗餘的模型狀態幹掉, 透過增加通訊來解決冗餘引數的問題. ZeRO原理動態圖
image

  • ZeRO1: 只保留一份MasterWeights+momentum+variance.
  • ZeRO2: 在ZeRO1的基礎上去除了grad的冗餘
  • ZeRO3: 在ZeRO2的基礎上去掉了weights的冗餘
image-20240627214641908

訓練流程

以ZeRO3為例. 主要分為5步, 假設使用了4張卡進行訓練:

  1. 每張卡上存1/4的W, OS和grad. 每張卡訓練自己分配到的batch.
  2. fp時, AllGather所有卡上的W,取到全量的W(fp16)進行fp, 完成後只保留自己需要維護的1/4 W, 其他視訊記憶體釋放回池
  3. bp時, AllGather所有卡上的W進行bp, 完成後再拋棄其他卡維護的W
  4. 完成bp後, ReduceScatter所有卡的G, 從其他卡上取到需要需要更新的梯度增量, 然後釋放不是自己維護的G.
  5. 使用自己維護的OS和G來更新W, 不需要通訊.
image-20240628163731199 image-20240628194209187

通訊量分析

定義單卡資料量為\(\phi\)

傳統DP: bp完成後需要對梯度進行一次AllReduce, 一共\(2\phi\)

ZeRO1: 只捨棄了OS, bp時需要AllReduce G(Scatter+Gather 共\(2\phi\)). 另外在使用每張卡各自更新W時, 因為W每張卡都儲存的全量, 需要從儲存OS的卡上把對應更新後的W再拉回來, 所以需要一次Gather(\(\phi\)), 一共需要\(3\phi\)

ZeRO2: 捨棄了OS和G, bp時AllGather G(\(\phi\)), 更新W時從其他卡拉W, 再Gather一次(\(\phi\)), 一共需要\(2\phi\)

ZeRO3: 上面訓練過程分析過, 共需要2次Gather和1次Scatter, 一共需要\(3\phi\)

可以看到ZeRO在通訊量只增加了1.5倍的情況下, 視訊記憶體降了60倍. 效果非常顯著

ZeRO++

ZeRO存在的問題是會在GPU之間產生大量資料傳輸開銷,降低了訓練效率. 主要有兩種情況:

  1. 全域性batch size較小,而 GPU數量多,這導致每個 GPU 上batch size較小,需要頻繁通訊

  2. 在低端叢集上進行訓練,其中跨節點網路頻寬有限,導致高通訊延遲。

ZeRO++主要採用了3部分最佳化: 權重量化 (qwZ), 分層分割儲存 (hpZ), 梯度量化 (qgZ). 對比ZeRO通訊量減少了4倍, 主要的難點都在減小量化帶來的訓練誤差

權重量化

    def _quantize_int8(self, tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        q_range = 2**self.config['num_bits'] - 1
        min_value = tensor.amin(dim=self.config['group_dim'] + 1, keepdim=True)
        max_value = tensor.amax(dim=self.config['group_dim'] + 1, keepdim=True)

        scale = q_range / (max_value - min_value)

        tensor = tensor.sub_(min_value).mul_(scale)
        tensor = tensor_round(tensor_clamp(tensor, 0, q_range)).to(torch.uint8)  #對稱式量化
        return tensor, scale, min_value

量化程式碼在deepspeedcsrc/quantization/quantize.cu cached_quantization 這個kernel裡.

如果採用全域性fp16->int8的量化會導致極大誤差. deepspeed採用了分割槽量化的方法, 把引數分為固定大小的block後, 先根據這個block的max/min計算出scale(量化係數), 在把這個引數傳入量化函式里. 另外在通訊的時候應該也需要每個block對應的係數傳給接收節點用於反量化.

\[量化公式: clip(round(scale * x), -2^{b-1}+1, 2^{b-1}-1) \]

透過這種方式在通訊量減半的同時還能保證精度, 很nice的思路.

image-20240628215923339

分層分割儲存

image-20240628194529492

之前ZeRO的W切分方法是根據卡數均分. 在fp/bp之前進行AllGather拉取, 後來發現在機器間進行Gather通訊是比較嚴重的瓶頸. 所以最後W的切分變成了每個節點記憶體儲全量的W, 節點內根據卡數進行切片. 避免跨節點經過網路卡的通訊, 透過增加視訊記憶體使用的方式解決通訊瓶頸.

視訊記憶體消耗: ZeRO3的單卡視訊記憶體消耗為 $\frac{(2+2+K)*\phi}{N} \(, 這裡每個節點多存了一份W, 如果有\)\alpha$個物理節點, 那麼每張卡使用的視訊記憶體就多了 \(\frac{\alpha * \phi}{N}\)

梯度量化

如果直接在之前zero RingAllReduce的通訊方式上加量化和反量化, 如下圖左, 可以看到需要節點個數次量化/反量化. 而每次量化都是有損的, 這樣會導致無法接受的訓練誤差. 為了解決這個問題zero++使用了一次量化->AllToAll通訊->一次反量化的操作. 而因為直接進行AllToAll通訊量從M(引數量)變成了M*N/Z(N: 節點數, Z:量化壓縮率), 這個通訊量的增長過大. deepspeed設計了2-hpop all-to-all方法來解決通訊問題.

image-20240628200906350

具體圖示流程可以參考Deepspeed的blog動態圖, 文字版步驟:
image

  1. 節點內的卡間張量切片重排. 主要是因為alltoall切分成了兩步, 如果不重排如下圖左. 最後順序會變錯位, 然後進行引數量化

    image-20240628210835122
  2. 節點內alltoall通訊後反量化.先把卡內能合併的梯度加起來. 這裡反量化主要是為了減小梯度累加的精度損失

  3. 再次量化後, 節點間進行allToAll

  4. 拿到通訊結果, 反量化後再次reduce. 得到最終的梯度.

這裡要進行兩次alltoall的原因主要是, 第一次卡間alltoall之後梯度累加可以減少卡數倍的通訊規模. 實際deepspeed在實現的時候還把重分片和量化kernel進行了fuse, 進一步最佳化效能

還有下圖的方法, 在通訊當前層的時候, 透過多流非同步量化下一層要通訊的資料. 避免同步等待的浪費

image-20240628211824538

參考

zero: https://arxiv.org/pdf/1910.02054

混合精度訓練: https://arxiv.org/pdf/1710.03740

zero++: https://arxiv.org/abs/2306.10209

Deepspeed blog: https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md

相關文章