編者按: 怎樣在 10,000 個 H100 GPU 上訓練大模型?如何充分利用每一塊 GPU 的算力?如何在這個複雜的 GPU 網路中高效傳遞資料?當不可避免的硬體故障發生時,又該如何快速恢復訓練進度?我們今天為大家帶來的文章中,作者為我們揭示了應對這些挑戰的關鍵策略。
作者 | Soumith Chintala
編譯 | 嶽揚
我的好友 Francois Fleuret 提出了上述問題。我迅速總結了一些在大規模訓練領域中相當普遍的知識,內容分為三部分。
- 首先,是如何將盡可能大的神經網路和 batch-size 適配到那 10000 張 H100s 上,這個步驟涉及到並行處理和使用節省記憶體的各種技巧。
- 其次,是如何在這些 GPU 之間儘可能高效地傳遞模型狀態資訊(state)。
- 最後,是如何在遇到硬體或軟體故障時,儘可能迅速地恢復系統。
01 如何將盡可能大的神經網路和 batch-size 適配到那 10000 張 H100s 上
1.1 並行策略
- 在資料批次(batches)上進行並行處理(資料並行(data parallel))
- 在神經網路層上進行並行處理(比如,將一層神經網路層分佈到多個 GPU 上進行計算)
- 對神經網路的不同模型層進行分割,以便它們能夠在不同的 GPU 上執行(比如,前 N 層執行在 GPU1 上,第 N+1 層到第 N+10 層執行在 GPU2 上)
持續最佳化並行策略,直到所有 GPU 都能被高效利用,達到最高利用率。
1.2 Checkpointing / Compute vs memorize
- 在執行前向傳播時,需要儲存一些中間結果以便後續計算反向傳播(save\_for\_backward)。然而,當神經網路規模非常大時,為了處理更大的資料批次,更有效的方法是釋放這些中間結果,待到需要計算反向傳播時再重新計算。
- 類似 FSDP 這樣的技術,透過在單個 GPU 上只保留模型的分片來節省記憶體。當需要其他權重時,會從其他 GPU 聚合模型的完整權重。
02 儘可能高效地在 GPU 叢集間傳遞模型狀態資訊
2.1 Communication overlap 策略:
在需要 GPU 間通訊時,應儘可能早地啟動通訊過程:
- 例如,當第 N 層完成反向傳播後,在第 N-1 層還在進行反向傳播計算時,負責第 N 層的所有 GPU 可以同時開始執行梯度全歸約操作。
2.2 探索並利用網路底層拓撲結構:
在多個計算節點間傳遞大量模型狀態資訊(如梯度、最佳化器狀態資訊)是一項複雜的任務。在使用 Sync SGD 時,需要儘可能快地集中傳輸這些狀態資訊。
網路中可能包含多層交換機,並具備 RDMA 能力(可以直接將 GPU 記憶體中的資料複製到網路卡,完全繞過 CPU 記憶體),同時擁有前端和後端網路卡(前端網路卡連線到如 NFS 之類的儲存系統,後端網路卡則將 GPU 連線到叢集中的其他 GPU)。
因此,在執行 all-reduce 或 scatter/gather 等通訊操作時,充分利用這些網路資訊至關重要。例如,透過樹形歸約演算法(tree-reduce),all-reduce 操作的時間複雜度可以降低到O(log(n));同時,網路光纖連線節點間的不同型別光纖對常數因子的影響,對於減少整體延遲時間也是非常重要的。
像 NCCL 這樣的庫能夠智慧地識別底層網路拓撲,並在執行 all-reduce 和其他通訊操作時加以利用。
在這樣的大規模計算中,我們還必須調整交換機和網路卡中的資料包路由演算法,以實現有效的負載均衡。交換機也需要大量的 HBM 記憶體(不僅僅是 GPU 需要),因為當資料包排隊等待時,需要在某個地方排隊而不會被丟棄——這就是交換機級別的 HBM 記憶體。
03 如何在遇到硬體或軟體故障時,儘可能迅速地恢復系統?
故障是不可避免的,涉及GPU、網路卡、電纜等多種硬體。有些故障能夠迅速被發現,而有些則可能因為某個節點沒有按時響應(比如 NCCL 的 all-reduce 操作卡住了)才被察覺。我們開發了多種工具來監控機群的健康狀況,並儘可能快地將故障節點從機群中移除。這可不是一件容易的事。
在這種規模下,記憶體位隨機翻轉導致的隱性資料損壞機率增加,可能導致訓練 loss 值異常升高。雖然這種問題在小規模系統中很少見,但在大規模系統中則可能頻繁發生。在軟體層面提前檢測這種問題非常困難。一些硬體裝置配備了內建校驗和的電路,可以在計算後進行校驗 —— 這樣,一旦發生位翻轉,硬體就能觸發中斷。但 H100 和之前的 NVIDIA GPU 都不具備這一功能。
為了應對這些故障,我們需要儘可能頻繁且迅速地儲存模型狀態資訊;一旦發生故障,我們也要能夠迅速恢復並繼續訓練。通常,我們會迅速將模型狀態資訊另存到 CPU 記憶體的一個獨立執行緒中,並在後臺將資料從 CPU 記憶體寫入到磁碟或遠端儲存系統。我們還以分片的形式儲存模型狀態資訊(利用了 torch.distributed 的 checkpointing 功能),也就是說,不是每個 GPU 都需要儲存完整的模型權重;每個 GPU 只需儲存一部分權重 —— 其餘部分可以透過其他 GPU 的分片 checkpoints 來恢復。
Thanks for reading!
Hope you have enjoyed and learned new things from this blog!
About the authors
Soumith Chintala
Cofounded and lead @PyTorch at Meta. Also dabble in robotics at NYU. AI is delicious when it is accessible and open-source.
END
本期互動內容 🍻
❓還記得你第一次配置分散式訓練環境時的經歷嗎?有什麼想對新手說的建議?
原文連結:
https://soumith.ch/blog/2024-10-02-training-10k-scale.md.html