[論文翻譯] 分散式訓練 Parameter sharding 之 ZeRO

羅西的思考發表於2022-01-11

[論文翻譯] 分散式訓練 Parameter sharding 之 ZeRO

0x00 摘要

Parameter sharding 就是把模型引數等切分到各個GPU之上,以此達到使用較少GPU實現大規模模型訓練的目的。本系列會以 Google,微軟和Facebook的論文,部落格以及程式碼來對parameter sharding 進行分析,大約有 5~6篇文章。

本文以 ZeRO: Memory Optimizations Toward Training Trillion Parameter ModelsDeepSpeed: Extreme-scale model training for everyone為主來進行分析,這是微軟開發的一個可以高效利用視訊記憶體的優化器,其會將模型狀態量(優化器狀態,梯度和模型引數)分佈在多個並行 GPU 之上,目的是在不使用模型並行的情況下對讓數十億引數模型進行訓練。

ZeRO是ZeRO-DP和ZeRO-R兩種方法的組合。ZeRO-DP是一種增強資料並行機制,它使用動態通訊策略來將優化器狀態、梯度和引數進行分割槽,以最小化通訊量和避免模型狀態的冗餘。ZeRO-R則使用分割槽啟用重計算、恆定大小緩衝區和動態記憶體碎片整理機制來優化剩餘狀態的記憶體消耗。

本文並不會逐字分析,而是選擇了一些重點,並且爭取加入筆者自己的理解。

本系列其他文章如下:

[原始碼解析] PyTorch 分散式之 ZeroRedundancyOptimizer

0x01 綜述

這部分主要翻譯自 DeepSpeed: Extreme-scale model training for everyone

1.1 挑戰

首先,我們需要了解訓練巨大模型所帶來的的視訊記憶體和計算效率的挑戰。

1.1.1 視訊記憶體效率

訓練萬億引數模型所需的視訊記憶體遠遠超出了單張 GPU 的視訊記憶體大小。比如,在使用 Adam 優化器進行混合精度訓練時需要約 16TB 的視訊記憶體來儲存模型狀態量(引數、梯度和優化器狀態量)。僅僅為了儲存模型狀態,就需要 400 張英偉達 A100 GPU(每張40 GB 的視訊記憶體)。

啟用函式也需要佔據額外的視訊記憶體,其隨批量大小(batch size)而增加。batch size設定為1的情況下,訓練萬億引數模型就會使用超過 1 TB 視訊記憶體來儲存啟用。人們也嘗試用 checkpoint 來處理啟用視訊記憶體,就是用計算來換視訊記憶體,這可以將該視訊記憶體減少到大約20 GB,但是對於訓練而言,這個視訊記憶體需求仍然過高。

所以,必須在多個 GPU 裝置之間有效地劃分模型狀態和啟用視訊記憶體,才能讓這種大模型在不耗盡視訊記憶體的情況下進行訓練。

1.1.2 計算效率

基於 OpenAI 的研究 law of scaling來估算,端到端訓練一個萬億引數的模型大約需要 5000 Zflops(即 5 後面帶有 24 個零)。訓練這樣一個模型需要 4000 張 A100 以 50% 的計算效率執行大約 100 天。

儘管大型超級計算 GPU 叢集可以擁有超過 4000 個 GPU,但是由於 batch size 的限制,要在這種規模上實現高計算效率仍然很具有挑戰性。計算效率隨著計算時間對通訊時間的比例的增加而增加。該比例與 batch size成正比。但是,模型可以訓練的 batch size有一個上限,如果超過這個上限,則收斂情況會迅速惡化。

1.2 權衡

我們接下來看看資料並行、模型並行和流水線並行之間的權衡。

1.2.1 資料並行

資料並行是深度學習中的十分常見的技術。在資料並行中,每批輸入的訓練資料都在資料並行的 worker 之間進行平分。反向傳播之後,我們需要進行通訊來規約梯度,以保證優化器在各個 worker 上可以得到相同的更新。資料並行性具有幾個明顯的優勢,包括計算效率高和工作量小。但是,資料並行的 batch size 會隨 worker 數量提高,而我們難以在不影響收斂性的情況下無限增加 batch szie。

  • 視訊記憶體效率:資料並行會在所有 worker 之間複製模型和優化器,因此視訊記憶體效率不高。
  • 計算效率:隨著並行度的提高,每個 worker 執行的計算量是恆定的。資料並行可以在小規模上實現近乎線性擴充套件。但是,因為在 worker 之間規約梯度的通訊成本跟模型大小成正相關,所以當模型很大或通訊頻寬很低時,計算效率會受到限制。梯度累積是一種常見的用來均攤通訊成本的策略,它可以增加batch size,在本地使用 micro-batch 進行多次正向和反向傳播,在進行優化器更新之前再規約梯度,從而分攤通訊成本。

1.2.2 模型並行

模型並行是另一大類技術。它可以在多個 worker 之間劃分模型的各個層。就其本質而言,模型並行的計算和通訊因模型結構而異,因此需要很大的工作量來實現。DeepSpeed 利用了英偉達的 Megatron-LM 來構建基於 Transformer 的大規模模型並行語言模型。模型並行會根據 worker 數量成比例地減少視訊記憶體使用,這是這三種並行模式中視訊記憶體效率最高的。但是其代價是計算效率最低。

  • 視訊記憶體效率:模型並行的視訊記憶體使用量可以根據 worker 數量成比例地減少。至關重要的是,這是減少單個網路層的啟用視訊記憶體的唯一方法。DeepSpeed 通過在模型並行 worker 之間劃分啟用視訊記憶體來進一步提高視訊記憶體效率。
  • 計算效率:因為每次前向和反向傳播中都需要額外通訊來傳遞啟用,模型並行的計算效率很低。模型並行需要高通訊頻寬,並且不能很好地擴充套件到通訊頻寬受限的單個節點之外。此外,每個模型並行worker 都會減少每個通訊階段之間執行的計算量,從而影響計算效率。模型並行性通常與資料並行性結合使用,以便在記憶體和計算效率之間進行權衡。

1.2.3 流水線並行

流水線並行將模型的各層劃分為可以並行處理的階段。當一個階段完成一個 micro-batch 的正向傳播時,啟用記憶體將被髮送給流水線的下一個階段。類似地,當下一階段完成反向傳播時,將通過流水線把梯度反向傳遞回來。為了確保流水線的各個階段能平行計算,必須同時計算多個 micro-batch 。目前已經有了幾種用於權衡記憶體和計算效率以及收斂行為的實現,例如PipeDream。DeepSpeed 通過梯度累積來實現並行,並在相同的總 batch size下,可以達到與傳統資料並行和模型並行訓練收斂情況相同。

  • 視訊記憶體效率:流水線並行減少的視訊記憶體與流水線的階段數成正比,這允許使模型的大小可以隨 worker 的數量線性擴充套件。但是,流水線並行不會減少每一層的啟用視訊記憶體佔用量。此外,每個 worker 必須儲存同時執行的各個 micro-batch 的啟用值。這導致流水線第一階段的啟用記憶體與單個 mirco batch 的總啟用記憶體大致相同。
  • 計算效率:因為流水線的通訊量只和階段邊界的各層的啟用值大小成正比,所以流水線並行的通訊量最低。但是,它不能無限擴充套件。像模型並行一樣,增加流水線大小會減少每個流水線階段的計算量,這會降低計算與通訊的比率。為了實現良好的計算效率,流水線並行還要求其每個階段都進行完美的計算負載均衡。

此外,流水線並行性會在每個 batch 的開始和結束時因為需要重新填充或排空流水線而產生 bubble overhead。使用流水線階段數的 4 倍或 8 倍的梯度累積步驟(以及 batch 大小)進行訓練,相較於只有一個流水線階段分別達到了 81% 和 90% 的擴充套件性。

1.3 通過 3D 並行實現記憶體和計算效率

資料、模型和管道並行在提高記憶體和計算效率方面各自發揮特定作用,所以DeepSpeed 結合了這三項強大的技術,可以訓練數萬億規模的模型並擴充套件到數千個 GPU。這三者的共生並行同時解決了訓練萬億引數模型的兩個基本挑戰:視訊記憶體效率計算效率,讓深度學習訓練的規模遠遠超出了單獨使用每種策略可以企及的高度。因此,DeepSpeed 可以在視訊記憶體中放下巨大的模型,而不會犧牲速度。

視訊記憶體效率:模型的各層首先被劃分到不同的流水線階段,其次,把每個階段的層通過模型並行進一步進行劃分。這種 2D 組合同時減少了模型、優化器和啟用所消耗的記憶體。然而,我們無法在不引入通訊開銷的情況下無限地劃分模型,而通訊開銷勢必會限制計算效率。

計算效率:為了在不犧牲計算效率的情況下將 worker 數量擴充套件至超出模型和流水線並行能支援的規模,我們使用了 ZeRO 支援的資料並行功能(ZeRO-DP)。ZeRO-DP 不僅可以通過劃分優化器狀態量進一步提高視訊記憶體利用效率,而且還可以通過利用基於通訊拓撲的對映關係,以最小的通訊開銷擴充套件到任意數量的 GPU。

圖 1:32 個worker的 3D 並行示例。神經網路的層被分為四個流水線階段。每個流水線階段內的層在四個模型並行worker之間進一步劃分。最後,每個流水線階段都跨兩個資料並行例項進行復制,ZeRO 在這兩個資料並行副本之間對優化器狀態進行分割槽。

下圖展示了通訊拓撲感知的 3D 對映:通過利用兩個關鍵的架構屬性,我們將 3D 並行中的每個維度仔細地對映到 worker 之上,從而最大化計算效率。

  1. 優化節點內和節點間的通訊頻寬:模型並行具有三種策略中最大的通訊開銷,因此我們優先考慮將模型並行 worker 組放置在節點內以利用更大的節點內頻寬。這裡我們基於英偉達 Megatron-LM 進行了張量切分式的模型並行。當模型並行組無法佔滿節點內的所有 worker 時,我們選擇將資料並行組放置在同一個節點內。否則它們將跨節點進行資料並行。流水線並行具有最低的通訊量,因此我們可以跨節點進行排程流水線的各個階段,而不受到通訊頻寬的限制。
  2. 通過並行通訊增大頻寬:每個資料並行組需要傳遞的梯度量隨著流水線和模型並行的規模而線性減小,因此3D總通訊量少於單純使用資料並行的通訊量。此外,每個資料並行組會在本地的一小部分 worker 內部獨立進行通訊,組間通訊可以相互並行。因此,通過減少通訊量和增加區域性性與並行性,我們可以有效的擴大資料並行通訊有效頻寬。

圖 2:圖 1 中的 worker 到八個節點(每個節點有四個 GPU)的系統上的 GPU 的對映。同一顏色的 GPU 在同一節點上。

1.4 3D 並行如何利用每種並行性

一個萬億引數模型可以使用 8 路模型並行、64 路管道並行和 8 路資料並行在 4,096 個 NVIDIA A100 GPU 上進行擴充套件。

  • 通過結合模型並行和流水線並行,3D 並行在多個節點之上實現了出色的記憶體效率和高效計算效率。模型並行性提高了節點內的啟用和模型狀態的儲存效率,流水線並行(相較於僅使用模型並行)則可以在不犧牲計算效率的情況下,跨節點高效儲存模型狀態。

  • 通過將模型並行性與流水線並行性相結合,即使在非常小的批量大小下,流水線並行性也能以最小的氣泡開銷實現高計算效率。在 8 路模型並行下,每個模型每使用 1 個微批次,將導致每個 GPU 的有效微批次大小為1/8。因此,流水線並行可以使用 8 倍流水線並行度的梯度累積步驟來實現 90% 的計算效率,並且每個 GPU 的總累積batch size 僅為 1。當與資料並行相結合時,這讓 4096 張 GPU 上的總有效 batch size 為 4096,並依然可以達到 90% 的流水線效率。

但是資料並行性會帶來怎樣的計算效率呢?資料並行性是否需要每個 GPU 擁有大批量才能保持高效?

模型並行可以將每張GPU上的有效 batch 大小減小到小於 1。這允許流水線並行即使在小 batch 下依然可以有效隱藏流水線 bubble 開銷。請注意,通過使用跨節點流水線並行性,我們就可以讓流水線每個階段的資料並行節點之間的通訊獨立發生並與其他流水線階段並行。實際上,在高階 GPU 叢集中常見的全連線的網路拓撲中,這對於資料並行訓練的可用有效通訊頻寬具有重要意義。由於流水線階段中的每個節點都可以與其對應的資料並行節點並行通訊,因此有效的通訊頻寬與流水線階段數量成正比。設定 64 個流水線並行級之後,有效頻寬將是往返於單個節點的頻寬的 64 倍。憑藉如此大的流水線並行有效頻寬,即使在計算與通訊比率非常低的小批量下,資料並行性也能有效擴充套件。

0x02 引論

我們接下來以 ZeRO: Memory Optimizations Toward Training Trillion Parameter Models 來進行學習。

2.1 原文摘要

大規模深度學習模型可以顯著提高accuracy,但訓練數十億到數萬億的引數是一項挑戰,因為單個GPU無法容納如此大的儲存模型和狀態量。現有的解決方案比如跨GPU的資料和模型並行存在了很大侷限性:這些方案雖然獲得了計算、通訊和開發效率,但都是在各種因素之間權衡,而且有一個最基本的問題:模型只能位於有限的裝置記憶體中。

論文作者開發了一種新的解決方案,使用零冗餘優化器(Zero)來優化記憶體,這樣可以極大地提高了訓練速度,同時增加了可以有效訓練的模型大小。ZeRO在尋求資料並行和模型並行的一個適當中間點,其希望消除資料和模型並行訓練中的記憶體冗餘,同時保持了較低的通訊量和較高的計算粒度,使我們能夠以持續的高效率按裝置數量比例調整模型大小。因此,ZeRO 可以獲得資料並行性和模型並行性的優點。可以用同樣記憶體來來執行更大的模型,可以使用資料並行方式來訓練那些以前只能使用模型並行進行訓練的模型。

2.2 原文引論

常見的資料並行性(DP)並不會減少每個裝置的記憶體,而其他現有解決方案,如管道並行(PP)、模型並行(MP)、CPU解除安裝(CPU-Offloading)等這些都是在功能性、可用性以及記憶體和計算/通訊效率之間進行權衡。在訓練大型模型的各種現有解決方案中,MP可能是最有前途的,然而,MP無法擴充套件到更大的尺寸。MP 會把模型進行垂直分割,將每一層中的計算和引數劃分到多個裝置上,這需要在每一層之間進行大量通訊。因此,它們在GPU間通訊頻寬較高的單個節點內工作良好,但在單個節點之外,效率會迅速下降,因此無法有效的擴充套件到單個節點之外。

那麼,我們如何克服現有解決方案的侷限性,更有效地訓練大型模型呢?為了回答這個問題,我們首先分析了現有系統在模型訓練方面的全部記憶體消耗,並將其分為兩部分:

  • 對於大型模型來說,大部分記憶體被模型狀態佔用,其中包括優化器狀態(如Adam中的動量和方差)、梯度和引數。
  • 剩餘的記憶體被:啟用、臨時緩衝區和不可用的碎片佔據,我們統稱為剩餘狀態。

因此,我們開發了Zero Redundancy Optimizer,在獲得高計算和通訊效率的同時,可以優化這兩個方面的記憶體效率。

2.2.1 優化模型狀態

模型狀態通常在訓練過程中消耗最大的記憶體量,但是現有的方法,如DP和MP並不能提供令人滿意的解決方案。

DP具有良好的計算/通訊效率,但記憶體效率較差,而MP的計算/通訊效率較差。更具體地說,DP在所有資料並行程式中複製整個模型狀態,導致冗餘記憶體消耗;雖然MP對這些狀態進行分割槽以獲得較高的記憶體效率,但往往會導致過於細粒度的計算和昂貴的通訊,從而降低了擴充套件效率。此外,這些方法靜態地維護整個訓練過程中所需的所有模型狀態,即使在訓練過程中並非始終需要所有模型狀態。

基於這些觀察結果,我們開發了ZeRO-DP,ZeRO-DP在保持DP的計算/通訊效率基礎之上,同時實現了MP的記憶體效率。ZeRO-DP通過對模型狀態進行分割槽而不是複製來消除資料並行程式中的記憶體狀態冗餘,這樣每個GPU之上的記憶體消耗將會和資料並行度成反比,並通過在訓練期間使用動態通訊排程來保留同DP基本一致的計算粒度和通訊量,這樣可以保持計算/通訊效率。

在模型訓練期間,大部分記憶體被以下三種情況之一消耗:

  • 啟用。
  • OGP狀態,即由優化器狀態(O),引數梯度(G)和引數本身(P)組成的張量。
  • 臨時緩衝區。

可能有人會問為什麼不考慮輸入資料的記憶體消耗,其實,輸入資料所佔用的視訊記憶體其實並不大,這是因為使用者基本使用迭代器讀取資料,這意味著資料並不是一次性全部讀入視訊記憶體,因此每次輸入所佔用的視訊記憶體與整個網路引數相比其實是微不足道的。

ZeRO DP有三個主要的優化階段(如下圖1所示),它們對應於優化器狀態、梯度和引數的分割槽。當逐步啟用時:

1)優化器狀態分割槽(Pos):記憶體減少4倍,通訊量與DP相同,此階段也被稱為 ZeRO-OS。

2)新增梯度分割槽(Pos+g)優化:記憶體減少8倍,通訊量與DP相同;

3)新增引數分割槽(Pos+g+p)優化:記憶體減少與DP的並行度成線性關係。模型記憶體被平均分配到每個GPU之上,每個gpu上的記憶體消耗與資料並行度成反比,但是通訊量只是適度增加。例如,跨64個GPU(Nd=64)拆分將產生64倍的記憶體縮減。通訊量適度增加了50%。

記憶體消耗具體可以參見下圖:

圖1:ZeRO-DP優化的三個階段之中每個裝置記憶體消耗比較。ψ表示模型大小(引數數量),K表示優化器狀態的記憶體乘數,Nd表示DP並行度,即Nd個GPU。在本例中,我們假設基於Adam優化器的混合精度訓練,模型大小為ψ=7.5B,DP為Nd=64,K=12。

2.2.2 優化殘餘狀態記憶體

在使用 ZeRO-DP 優化模型狀態對應的記憶體之後,殘餘記憶體(Residual State Memory)成為次要記憶體瓶頸,剩餘記憶體包括:啟用、臨時緩衝區和不可用記憶體片段。我們開發了ZeRO-R來分別優化這三個因素所消耗的剩餘記憶體。

  • 對於啟用(從前向傳播結果之中儲存,用來支援後向傳播),我們注意到優化檢查點會有幫助,但是對於大型模型不夠用。因此,ZeRO-R通過在現有MP方案中識別和刪除啟用副本來優化啟用記憶體。它還可以在適當的時候將啟用解除安裝到CPU。

  • ZeRO-R為臨時緩衝區定義了適當的大小,以實現記憶體和計算效率的平衡。

  • 我們觀察到在訓練中,由於不同張量生命週期的變化而會導致一些記憶體碎片。由於這些碎片的存在,會導致即便即使有足夠的可用記憶體,也會因為缺少連續記憶體而使得記憶體分配失敗。ZeRO-R根據張量的不同生命週期來主動管理記憶體,防止記憶體碎片。

ZeRO-DP和ZeRO-R結合在一起形成了一個強大的DL訓練記憶體優化系統,我們統稱為ZeRO。

2.2.3 ZeRO和MP

因為ZeRO消除了DP中的記憶體效率不足,所以很自然地會問:我們還需要MP嗎?什麼時候需要?ZeRO如何與MP一起工作?

使用ZeRO之後,MP對於大型模型就不太有吸引力了。ZeRO-DP在減少每裝置記憶體佔用方面至少與MP一樣有效,或者在MP無法均勻劃分模型時更有效。它還具有相當或更好的縮放效率。此外,資料並行非常容易使用,因此它廣泛適用於不同的工作負載,而如今的MP方法通常需要模型開發人員的一些額外工作來修改其模型,比如現有的工作(如Megatron-LM)只支援有限的操作和模型集。

儘管如此,仍然有一些情況下我們希望利用MP:

i)當與ZeRO-R一起使用時,MP可以減少大型模型的啟用記憶體佔用。

ii)對於啟用記憶體不是問題的較小模型。當單獨使用DP時,可能會因為聚合batch size太大而無法實現良好的收斂性,這時候MP也有好處。在這種情況下,可以將ZeRO與MP結合起來,使模型具有可接受的聚合batch size。

0x03 相關工作

3.1 資料,模型和流水線並行

並行化是大型模型訓練的關鍵策略。對於可以塞進裝置記憶體的模型,資料並行(DP)用於將訓練擴充套件到多個裝置。在DP中,模型引數複製到每個裝置上。在每個步驟中,一個小批量被均勻地分發到所有資料並行程式中,這樣每個程式都會對不同的資料樣本子集執行正向和反向傳播,並使用程式間的平均梯度來區域性更新模型。

當一個模型不適合裝置記憶體時,模型並行性(MP)和流水線並行性(PP)分別以垂直和水平方式在程式之間分割模型。

PP在層之間水平拆分一個模型,在不同裝置上執行不同的分割槽,並使用微批處理隱藏管道氣泡。由於水平拆分和micro-batching,所以某些功能(如tied-weight和batch-normalization)難以實現。

流行的PP實現(如G-pipe)同時對模型引數和總啟用進行分割槽,但需要與管道分割槽數量成比例的batch size來隱藏管道氣泡。大batch size可能會影響收斂速度,PP同時也需要大量記憶體來儲存啟用。

PipeDream是PP的另一種實現,其保留了過時引數的多個副本,以隱藏管道氣泡,而不會顯著增加batch size,從而可以降低記憶體效率。此外,該實現不等同於標準DL訓練,並且對訓練收斂有影響。

相比之下,ZeRO獲得了與PP相同或更好的記憶體效率,而不會有PP帶來的功能、效能和與收斂的限制。

3.2 非並行方面的工作

原小標題為Non-parallelism based approach to reduce memory。

除了MP和PP之外,還有很多旨在減少DL訓練記憶體開銷的工作。

3.2.1 減少啟用記憶體

目前,有很多工作集中在減少啟用的記憶體佔用上,包括壓縮、啟用檢查點或實時分析。這些努力是互補的,可以與ZeRO一起工作。事實上,ZeRO-R中的啟用記憶體減少完全可以和啟用檢查點並行工作。

3.2.2 CPU Offload

也有一些工作利用計算節點的異構性,分別通過演算法設計或虛擬化記憶體將模型狀態轉移到CPU記憶體。但是這導致有50%的時間被浪費在GPU-CPU-GPU傳輸。ZeRO的不同之處在於,它顯著降低了記憶體消耗,而無需將模型狀態儲存到CPU記憶體中。在極少數情況下,ZeRO-R可能只針對非常大的模型才解除安裝啟用檢查點,以提高效能。

3.2.3 記憶體高效(Efficient)優化器

另一些工作是通過獲取模型引數和梯度的粗粒度統計資料來減少自適應優化方法的記憶體消耗,這可能會對模型收斂保證產生影響。ZeRO與這些工作是正交的,它的優化不會改變模型優化方法或影響模型收斂,但會有效地減少每個裝置的優化器狀態和梯度的記憶體佔用。

3.3 訓練優化器

對於大型模型,自適應優化(Adaptive)方法對於達到SOTA效能和精度至關重要。與SGD相比,它以顯著的記憶體佔用為代價,維護每個模型引數和梯度的細粒度一階和二階統計資訊。ZeRO可以將這些優化器的記憶體佔用減少幾個數量級,使這些複雜的優化方法對於在具有適度裝置記憶體的硬體上訓練大型模型非常實用。它還讓人們可以開發和使用更復雜、記憶體消耗更大、收斂性更好的優化器。

0x04 模型記憶體都去哪裡了?

讓我們退一步來研究一下當前訓練系統的記憶體消耗。例如,一個1.5B引數的GPT-2模型需要3GB記憶體用於16位精度的權重(或引數),但是人們卻不能使用Tensorflow或PyTorch在一個32GB記憶體的GPU上進行訓練。人們可能想知道所有的記憶體都去了哪裡。在模型訓練期間,大部分記憶體被模型狀態消耗,即由optimizer狀態、梯度和引數組成的張量。除了這些模型狀態,其餘的記憶體被啟用、臨時緩衝區和碎片化記憶體消耗,我們稱之為剩餘狀態。我們將從這兩個方面詳細研究記憶體消耗。

4.1模型狀態:優化器狀態,梯度和引數

原小標題為:Model States: Optimizer States, Gradients and Parameters

大多數裝置記憶體在訓練期間由模型狀態消耗。例如,用Adam,DL訓練中最流行的優化器之一作為例子。Adam需要儲存兩個優化器狀態,i)時間平均動量(time averaged momentum)和ii)梯度方差(variance of the gradients)來計算更新。因此,要使用ADAM訓練模型,必須有足夠的記憶體來儲存梯度動量和方差的副本。此外,也需要有足夠的記憶體來儲存梯度和權重本身。在這三種型別的引數相關張量中,優化器狀態通常消耗最多的記憶體,特別是在應用混合精度訓練時。

4.1.1 混合精度訓練

在當前一代NVIDIA GPU上訓練大型模型的最先進方法是通過混合精度(fp16/32)訓練,在這個方法中,引數和啟用儲存為fp16,從而能夠在這些GPU上使用高吞吐的張量核心單元。在混合精度訓練期間,優化器使用fp16權重和啟用執行正向和反向傳播。但是,為了在反向傳播結束時有效地計算和應用權重更新,混合精度優化器必須保留引數的fp32副本以及所有其他優化器狀態的fp32副本。

讓我們以Adam優化器為例。使用Adam對帶有ψ個引數的模型進行混合精度訓練需要足夠的記憶體來儲存引數和梯度的fp16副本。其記憶體需求分別為2ψ和2ψ位元組。此外,它還需要儲存優化器狀態,引數動量和方差的fp32副本,其記憶體需求分別為4ψ,4ψ和4ψ位元組。

讓我們使用K來表示優化器狀態的記憶體乘數(multiplier),即儲存它們所需的額外記憶體是Kψ位元組。混合精度Adam的K=12。總的來說,這將產生2ψ+2ψ+Kψ=16ψ位元組的記憶體需求。對於具有15億個引數的GPT-2這樣的模型,這至少需要24GB的記憶體,遠遠高於單獨儲存fp16引數所需的3GB記憶體。

4.2 剩餘記憶體佔用

原標題為 Residual Memory Consumption

4.2.1 啟用

在訓練期間,啟用會佔用大量的記憶體。作為一個具體的例子,1.5B引數的GPT-2模型以1K的序列長度和32的batch size 進行訓練,需要大約60GB的記憶體。啟用檢查點(或啟用重新計算)是一種常用的方法,可將啟用記憶體減少到總啟用的平方根,但需花費33%的重新計算開銷。這將使此模型的啟用記憶體消耗減少到約 8 GB。

儘管有顯著的減少,但對於更大的模型,即便使用啟用檢查點,啟用記憶體也會變得相當大。例如,一個具有1000億個引數的類GPT模型,對於32大小的batch size,則需要大約60 GB的記憶體,即使使用啟用檢查點也是如此。

4.2.2 臨時緩衝區

對於大型模型,用於儲存中間結果的臨時緩衝區會消耗大量記憶體。有些操作,比如gradient all-reduce或者 gradient norm computation 會傾向於將所有梯度融合到單個平坦緩衝區中,以此來執行一個統一操作,這樣可以提高吞吐量。例如,所有裝置的頻寬都會隨著訊息的增大而降低。雖然梯度本身通常儲存為fp16張量,但融合緩衝區可以是fp32張量(具體取決於操作型別)。當模型較大時,這些臨時緩衝區大小是非常重要的。例如,對於引數為1.5B的模型,扁平fp32緩衝區需要6GB記憶體。

4.2.3 記憶體碎片

到目前為止,我們已經討論了訓練期間的實際記憶體消耗。此外,即使有足夠的可用記憶體,也可能耗盡可用記憶體。記憶體碎片就可能導致這種情況。如果沒有足夠的連續記憶體來滿足對記憶體的請求,即使總可用記憶體大於請求的記憶體,對記憶體的請求也會失敗。在訓練非常大的模型時,我們觀察到明顯的記憶體碎片,這會導致記憶體不足問題,在某些極端情況下,即使超過30%的記憶體仍然可用,依然無法分配記憶體。

0x05 ZeRO: 感悟和概述

ZeRO有兩組優化:i)ZeRO DP旨在減少模型狀態的記憶體佔用,ii)ZeRO-R旨在減少剩餘記憶體消耗。

5.1 感悟和概述: ZeRO-DP

ZeRO powered DP 基於三個關鍵感悟:

  • DP比MP具有更好的擴充套件效率,因為MP降低了計算粒度,同時也增加了通訊開銷。超過某一點後,較低的計算粒度會降低每個GPU的效率,而增加的通訊開銷會隱藏跨GPU的可伸縮性,特別是在跨越節點邊界時。相反,DP具有更高的計算粒度和更低的通訊量,從而帶來更高的效率。

  • DP記憶體效率低下,因為模型狀態在所有資料並行程式中冗餘儲存。相反,MP對模型狀態進行分割槽以獲得記憶體效率。

  • DP和MP都保留了整個訓練過程中所需的模型所有狀態,但並非所有情況都需要。例如,每個層的引數僅在層的正向傳播和反向傳播期間需要。

基於這些感悟,ZeRO DP 在保留DP的訓練效率的同時,也實現了MP的記憶體效率。ZeRO DP對模型狀態進行分割槽,而不是複製它們,並使用動態通訊計劃,該計劃利用模型狀態固有的時間特性,同時最小化通訊量。通過這樣做,ZeRO-DP隨著DP度的增加線性地減少了模型的每裝置記憶體佔用,同時保持通訊量接近預設DP,這樣就保持了效率。

5.2 感悟和概述: ZeRO-R

5.2.1 降低啟用記憶體

兩個關鍵感悟是:

  • MP對模型狀態進行分割槽,但通常需要複製啟用記憶體。例如,如果我們垂直分割一個線性層的引數並跨兩個GPU平行計算它們,那麼每個GPU都需要整個啟用來計算其分割槽。

  • 對於GPT-2或更大的模型,算術強度(每次迭代的計算量與每次迭代的啟用檢查點量之比)非常大(≥ 10K),並隨著隱藏維度增加而線性增加,從而可以隱藏啟用檢查點的資料移動成本,即使在頻寬較低的情況下也是如此。

ZeRO通過跨GPU劃分啟用檢查點來消除MP中的記憶體冗餘,並使用allgather按需重建它們。啟用記憶體的減少與MP程度成比例。對於非常大的模型,ZeRO甚至可以選擇將啟用分割槽移動到CPU記憶體中,同時由於這些模型中的運算強度很大,因此仍然可以實現良好的效率。

5.2.2 管理臨時緩衝區

ZeRO-R使用恆定大小的緩衝區來避免臨時緩衝區隨著模型大小的增加而崩潰,同時使它們足夠大以保持效率。

5.2.3 管理記憶體碎片

記憶體碎片是短生命週期記憶體物件和長生命週期記憶體物件交錯分配的結果。在正向傳播期間,啟用檢查點的壽命很長,但重新計算的啟用壽命很短。同樣,在反向計算中,啟用梯度的壽命很短,而引數梯度的壽命很長。基於這一認識,ZeRO通過將啟用檢查點和梯度移動到預先分配的連續記憶體緩衝區來執行動態記憶體碎片整理。這不僅提高了記憶體可用性,還通過減少記憶體分配器查詢可用連續記憶體所需的時間來提高效率。

0x06 深入瞭解 ZeRO-DP

雖然現有的DP方法在每個裝置上覆制模型狀態並引入顯著的記憶體開銷,但ZeRO DP通過跨資料並行的程式對它們(優化器狀態、梯度和引數)進行分割槽來消除這種記憶體冗餘。圖1量化並視覺化了有無ZeRO-DP的記憶體需求。該圖顯示了(1)優化器狀態(2)梯度和(3)引數累積冗餘 這三種引數在分割槽後的記憶體佔用。我們將其稱為ZeRO DP的三個優化階段:Pos、Pg和Pp,我們將在下面詳細說明。這裡把圖一再次貼出來。

6.1 Pos : 優化器狀態分割槽

對於一個\(N_d\)並行度的DP來說,我們將優化器狀態分組到\(N_d\)個相等的分割槽中,這樣第i個資料並行程式只更新與第i個分割槽對應的優化器狀態。因此,每個資料並行過程只需要儲存和更新總優化器狀態 的$ \frac{1}{N_d}\(,然後只更新\) \frac{1}{N_d}$個引數。在每個訓練步驟結束時,我們會執行一個跨資料並行程式的all-gather操作,以獲得跨所有資料並行程式的完全更新的引數。

如圖1所示的具體示例,7.5 B引數模型,使用64路DP(\(N_d\)=64),其Pos需要31.4GB記憶體。而使用標準DP則需要120 GB記憶體。此外,當\(N_d\)較大時,模型狀態的記憶體需求從4ψ+12ψ=16ψ位元組減少到4ψ+\(\frac{12ψ}{N_d}\)位元組≈ 4ψ位元組,導致4x倍數的減少。

6.2 Pg: 梯度分割槽

由於每個資料並行程式只負責更新其相應的引數分割槽,因此,每個節點僅僅對自己負責的那部分引數的梯度進行規約。在歸併之後,每個節點只需要自己引數分割槽對應的梯度,對於其他的梯度不再需要,所以它們的記憶體可以被釋放。這將梯度的記憶體佔用從2ψ位元組縮減到 \(\frac{2ψ}{N_d}\)

實際上,這是一種 Reduce-Scatter操作,不同引數的梯度被減少到不同的程式之中。為了提高效率,我們使用了bucketization策略,其中我們將對應於特定分割槽的所有梯度bucketization,並立即對整個bucket執行規約。在我們的例子中,我們在分割槽邊界執行一個reduce而不是 all-reduce,以減少記憶體佔用,並重疊計算和通訊。

記憶體節省:通過消除梯度和優化器狀態冗餘,我們將記憶體佔用進一步降低到2ψ+ \(\frac{14ψ}{N_d}\)≈ 2Ψ. 如圖1中的示例所示,7.5 B引數模型使用Pos+g和64路DP(Nd=64)時只需要16.6 GB記憶體,而使用標準DP時需要120 GB記憶體。當\(N_d\)較大時,模型狀態的記憶體需求從2ψ+14ψ=16ψ位元組減少到2ψ+\(\frac{14ψ}{N_d}\)位元組≈ 2ψ位元組,減少8倍。

6.3 Pp: 引數分割槽

就像優化器狀態和梯度一樣,每個程式只儲存與其分割槽對應的引數。當正向和反向傳播需要其分割槽外的引數時,會通過broadcast操作從適當的資料並行程式接收這些引數。雖然乍一看,這可能會導致顯著的通訊開銷,但我們發現,這種方法只會將基線DP系統的總通訊量增加到1.5倍,同時實現與Nd成比例的記憶體減少。

記憶體節省:通過引數分割槽,我們將ψ個引數的記憶體佔用從16ψ降低到 \(\frac{16ψ}{N_d}\)。 如圖1中的示例所示,7.5 B引數模型使用\(P_{os+g+p}\)和64路DP(Nd=64)時只需要1.9 GB記憶體,而使用標準DP時需要120 GB記憶體。

這有著深刻的含義:只要有足夠數量的裝置來共享模型狀態,ZeRO-DP就可以適合任意大小的模型。

6.4 對模型大小的影響

分割槽Pos、Pos+g和Pos+g+p的三個階段分別將模型狀態下每個資料並行程式的記憶體消耗減少了4倍、8倍和\(N_d\)倍。表1分析了幾個示例模型在不同DP程度下,ZeRO-DP 3個階段下的模型狀態記憶體消耗。

如果不使用ZeRO,則無論DP程度如何,記憶體消耗都等於表中的第一行。注意,當Nd=64時,ZeRO可以分別使用Pos、Pos+g和Pos+g+p來訓練引數高達7.5B、14B和128B的模型。當Nd=1024時,啟用所有優化的ZeRO(Pos+g+p)可以訓練具有1萬億個引數的模型!或者可能是任意大小的模型!如果沒有ZeRO,DP可以執行的最大模型的引數才不到15億個。

0x07 深入 ZeRO-R

7.1 \(P_a\): 將 Activation Checkpointing 分割槽

正如前面所討論,MP 在設計上就要求複製啟用,從而在模型並行GPU之間產生啟用的冗餘副本。ZeRO通過對啟用進行分割槽來消除這種冗餘,並且在啟用用於計算之前,才只以一個啟用層的副本形式將它們一次性具化。

更具體地說,一旦計算了模型中一個層的前向傳播,輸入啟用將在所有模型並行過程中進行分割槽,直到在反向傳播中再次需要它。此時,ZeRO使用all gather操作重新具化啟用的複製副本。我們將此優化稱為Pa。它與啟用檢查點一起工作,只儲存分割槽的啟用檢查點,而不是複製副本。此外,在非常大的模型和非常有限的裝置記憶體的情況下,這些分割槽的啟用檢查點也可以解除安裝到CPU上,以額外的通訊成本將啟用記憶體開銷降低到幾乎為零,我們稱之為\(P_{a+cpu}\)

通過分割槽啟用檢查點,ZeRO將啟用佔用空間減少了一個與MP程度成比例的因子。考慮訓練一個100B模型,其批大小為32,序列長度為1024,MP的度數為16。如果我們為每個轉換器層檢查一個啟用,那麼僅儲存啟用檢查點就需要每個GPU大約33 GB的記憶體。但如果Pa為零,則每個GPU的容量可以減少到2GB左右。此外,這個2GB可以解除安裝到CPU上,從而將啟用的記憶體佔用減少到幾乎為零。

7.2 CB: 固定大小緩衝區

ZeRO仔細選擇臨時資料緩衝區的大小,以平衡記憶體和計算效率。在訓練期間,某些操作的計算效率可能高度依賴於輸入大小,輸入越大,效率越高。例如,一個大的all-reduce操作比一個小的操作獲得更高的頻寬。因此,為了獲得更好的效率,高效能庫(如NVIDIA Apex或Megatron)在應用這些操作之前將所有引數融合到單個緩衝區中。然而,融合緩衝器的記憶體開銷與模型大小成正比。例如,對於3B引數模型,32位融合緩衝區將需要12GB記憶體。為了解決這個問題,當模型變得太大時,我們只需使用一個高效能的固定大小融合緩衝區(Constant Size Buffers)。通過這樣做,緩衝區大小不依賴於模型大小,並且通過保持足夠大的緩衝區大小,我們仍然可以實現良好的效率。

7.3 MD: 記憶體碎片整理

模型訓練中的記憶體碎片是啟用檢查點和梯度計算的結果。在帶有啟用檢查點的前向傳播期間,只有選定的啟用被儲存用於反向傳播,而大多數啟用被丟棄,因為它們可以在反向傳播期間重新計算。這將建立短期記憶體(丟棄的啟用)和長期記憶體(檢查點的啟用)的交錯,導致記憶體碎片。類似地,在反向傳播期間,引數梯度是長生命週期的,而啟用梯度和計算引數梯度所需的任何其他緩衝區是短生命週期的。同樣,這種短期記憶體和長期記憶體的交錯也會導致記憶碎片。

當有足夠的記憶體可用時,有限記憶體碎片通常不是問題,但對於使用有限記憶體執行的大型模型訓練,記憶體碎片會導致兩個問題,i)由於缺乏連續記憶體,即使有足夠的可用記憶體也會導致OOM,ii)由於記憶體分配器花費大量時間搜尋連續記憶體塊以滿足記憶體請求,導致效率低下。

ZeRO通過為啟用檢查點和漸變預先分配連續記憶體塊,並在生成時將它們複製到預先分配的記憶體中,動態地進行記憶體碎片整理。MD不僅使ZeRO能夠以更大的批量訓練更大的模型,而且還可以在記憶體有限的情況下提高訓練效率。

0x08 ZeRO-DP 通訊量分析

由於ZeRO通過消除記憶體冗餘提高了可以訓練模型的大小,所以很自然地會有疑問,是否在用通訊量換取記憶體效率的問題。換句話說,與基線DP方法相比,ZeRO-powered到DP方法的通訊量是多少?答案分為兩個部分:i)ZeRO-DP使用Pos和Pg的時候並不會產生額外的通訊,而可實現8倍的記憶體縮減;ii)ZeRO-DP在使用Pos和Pg之外的Pp時,最多會產生1.5倍的通訊,但同時進一步將記憶體佔用減少了Nd倍。

8.1 資料並行通訊量

資料並行訓練期間,在反向傳播結束,而在計算下一步的更新之前,會對所有資料並行程式的梯度進行平均。平均操作是使用all-reduce來完成的。對於大型模型,all- reduce通訊完全受通訊頻寬的限制,因此,我們的分析僅限於每個資料並行程式之間傳送和傳送的總通訊量。

all-reduce的最新實現一般採用兩步方法,第一步是reduce-scatter操作,它規約了不同程式上資料的不同部分。下一步是all gather操作,其中每個程式收集所有程式上規約的資料。這兩個步驟的結果就是一個all-reduce操作。“reduce-scatter”和“all-gather”都是使用流水線方法實現的,這會導致總共ψ個元素(假設資料包含ψ個元素)的資料移動。因此,標準DP在每個訓練步驟中產生2ψ個資料移動。

8.2 ZeRO-DP 通訊量

8.2.1 使用 Pos+g 的通訊量

使用梯度分割槽之後,每個程式只儲存梯度的一部分,這是更新其相應的引數分割槽所必需的。因此,與all-reduce不同,ZeRO只需要在梯度上進行scatter-reduce操作,從而產生ψ的通訊量。在每個程式更新其負責的引數的分割槽後,將執行all-gather以從所有資料並行程式收集所有更新的引數。這也導致通訊量為ψ。因此,每個訓練步驟的總通訊量為ψ+ψ=2ψ,與基線DP完全相同。

8.2.2 使用 Pos+g+p 的通訊量

在引數分割槽之後,每個資料並行程式只儲存它負責更新的引數。因此,在前向傳播期間,它需要接收所有其他分割槽的引數。但是,這可以通過流水線來避免記憶體開銷。在計算與特定分割槽對應的模型部分的前向傳播之前,負責該分割槽的資料並行程式可以將權重廣播給所有資料並行程式。一旦該分割槽的前向傳播完成,就可以丟棄這些引數。因此,總通訊量為\(\frac{ψ×N_d}{N_d}=ψ\)。換言之,我們依靠在整個正向傳播中傳播來重新安排引數的all-gather,並在使用引數後丟棄這些引數。但是請注意,對於反向傳播,需要再次進行此all-gather(但是以相反的順序)。

因此,總通訊量是reduce-scatter和 all-gather所產生的通訊量的總和,總體積為3ψ,是基線的1.5倍。梯度和引數分割槽都利用了這樣一種洞察,即並非所有的梯度和引數狀態都是始終需要的,而是通過明智地傳遞狀態來優化記憶體。

0x09 ZeRO-R 通訊分析

我們將ZeRO-R中的分割槽啟用檢查點(Pa)的通訊量與基線MP進行了比較,結果表明,Pa引起的通訊量增加通常不到基線MP的十分之一。此外,我們分析了Pa的通訊開銷與DP通訊量的關係,以確定Pa通過允許更大的批量和減少DP通訊來提高效率的情況。我們利用這種分析來決定是否以及何時應用Pa以及Pa+cpu。

分割槽啟用檢查點的通訊量權衡取決於模型大小、檢查點策略和MP策略。為了分享具體的見解,我們在使用SOTA MP方法(Megatron-LM)實現的模型背景下進行分析。

在帶有啟用檢查點的Megatron-LM中,每個transformer在正向傳播中執行兩個大小為batch×seq_length×hidden _dim 的all-reduce操作,兩個all-reduce操作用於正向重新計算,另外兩個all-reduce操作用於反向傳播。每個塊的總通訊量為 12 × seq length × hidden dim,因為all reduce的通訊量為2 × message size。

當ZeRO-R對啟用檢查點進行分割槽時,需要在每個啟用檢查點上向前重新計算反向傳播之前執行額外的all gather操作。通常,我們檢查每個transformer塊的輸入啟用,每個轉換器塊需要一個all gather。因此,通訊開銷Pa為seq length ∗ hidden dim,因為所有聚集的通訊量為message size。因此,Pa的總通訊開銷小於模型並行原始通訊量的10%。

當MP與DP結合使用時,Pa可用於將資料並行通訊量減少一個數量級,而模型並行通訊量增加10%,並在資料並行通訊成為效能瓶頸時顯著提高效率。請注意,Pa將啟用記憶體消耗降低了MP並行度,從而允許按比例增加批處理大小。對於大型模型,MP可以大到16個(DGX-2節點上的#GPU),允許批量大小最多增加16倍。資料並行訓練的通訊量與批量大小成反比。因此,由於Pa導致批量大小增加一個數量級可能導致資料並行通訊量減少一個數量級。

最後,如果採用Pa+cpu,分割槽啟用檢查點將解除安裝到cpu,啟用記憶體需求將減少到幾乎為零,與Pa相比,cpu記憶體之間增加了2倍的資料移動。在極端情況下,DP通訊量是主要瓶頸,因為即使使用Pa,批大小也很小,只要cpu資料傳輸開銷小於DP通訊量開銷,Pa+cpu就可以通過增加批處理大小來提高效率,這通常適用於小批處理大小。

在給定模型和硬體特性的情況下,我們利用上述分析來決定是否以及何時應用Pa和Pa+cpu。

0xFF 參考

論文解讀系列第十三篇:ZeRO——面向萬億級引數的模型訓練方法

[譯] DeepSpeed:所有人都能用的超大規模模型訓練工具

DeepSpeed: Extreme-scale model training for everyone

相關文章