[原始碼分析] Facebook如何訓練超大模型---(1)

羅西的思考發表於2022-01-17

[原始碼分析] Facebook如何訓練超大模型---(1)

0x00 摘要

我們在前文介紹過,微軟 ZeRO 可以對一個萬億引數模型可以使用 8 路模型並行、64 路管道並行和 8 路資料並行在 4,096 個 NVIDIA A100 GPU 上進行擴充套件。

而FSDP(Fully Sharded Data Parallel)是Facebook 深度借鑑微軟ZeRO之後提出的PyTorch DDP升級版本,可以認為是對標微軟 ZeRO,其本質是 parameter sharding。Parameter sharding 就是把模型引數等切分到各個GPU之上。我們會以 Google,微軟和 Facebook 的論文,部落格以及程式碼來進行學習分析。

本系列其他文章如下:

[原始碼解析] PyTorch 分散式之 ZeroRedundancyOptimizer

[論文翻譯] 分散式訓練 Parameter sharding 之 ZeRO

[論文翻譯] 分散式訓練 Parameter Sharding 之 Google Weight Sharding

0x01 簡介

1.1 FAIR & FSDP

大規模訓練人工智慧模型並不容易。除了需要大量的計算能力和資源外,訓練非常大的模型背後還有相當大的工程複雜性。Facebook人工智慧研究(FAIR)工程部一直致力於構建工具和基礎設施,以使大型人工智慧模型的培訓變得更容易。

Fully Sharded Data Parallel(FSDP)是FAIR引入的最新工具。它將AI模型的引數在資料並行worker之間進行切分,並且可以選擇將部分訓練計算解除安裝到CPU。顧名思義,FSDP是一種資料並行訓練演算法。儘管引數被分片到不同的GPU,但每個微批次資料的計算對於每個GPU worker來說仍然是本地的。這種概念上的簡單性使FSDP更易於理解,並且更適用於各種使用場景(與層內並行和流水線並行相比)。與optimizer state+gradient sharding資料並行方法相比,FSDP在訓練過程中通過通訊和計算重疊對模型引數進行更均勻的切分,具有更好的效能。

FSDP可以使用更少的GPU更有效地訓練數量級更大的模型。FSDP已在FairScale庫中實現,允許工程師和開發人員使用簡單的API擴充套件和優化其模型的培訓。在Facebook,FSDP已經被整合並測試,用於訓練一些NLP和Vision模型。

1.2 大規模訓練計算能力需求

大規模模型訓練需要大量的計算資源,比如OpenAI的GPT-3 擁有1750億個引數。其訓練估計需要355年的GPU時間,相當於1000個GPU連續工作4個月以上。

除了需要大量計算和工程資源外,大多數的訓練擴充套件方法都會帶來額外的通訊成本,並且需要工程師仔細評估記憶體使用和計算效率之間的權衡。例如,典型的資料並行培訓要求在每個GPU上維護模型的冗餘副本,而模型並行培訓為在worker(GPU)之間移動啟用引入了額外的通訊成本。

相比之下,FSDP相對而言沒有做任何權衡。它通過在GPU上分割模型引數、梯度和優化器狀態來提高記憶體效率,並通過分解通訊並將其與前向和後向過程重疊來提高計算效率。FSDP產生與標準分散式資料並行(DDP)培訓相同的結果,並提供易於使用的介面,該介面是PyTorch分散式資料並行模組的替代品。Facebook 的早期測試表明,FSDP可以擴充套件到數萬億個引數。

0x02 FSDP 如何工作

在標準DDP訓練中,每個worker處理一個單獨的批次,並使用all-reduce對worker之間的梯度進行彙總。雖然DDP已經變得非常流行,但它佔用的GPU記憶體比它實際需要的要多,因為模型權重和優化器狀態在所有DDP worker中都有一個副本

2.1 全引數分片

減少副本的一種方法是應用全引數分片( full parameter sharding)的過程,其中僅提供區域性計算所需的模型引數、梯度和優化器的子集。這種方法的一個實現 ZeRO-3 已經被微軟所普及。

解鎖全引數切分的關鍵是:我們可以把DDP之中的all reduce操作分解為獨立的 reduce-scatter 和 all-gather 操作

圖來自 :https://engineering.fb.com/wp-content/uploads/2021/07/FSDP-graph-2a.png?w=1024

“All-reduce”是“reduce-scatter”和“all-gather”的組合。聚合梯度的標準 “All-reduce”操作可以分解為兩個單獨的階段:“reduce-scatter”和“all-gather”。

  • “reduce-scatter”階段,在每個GPU上,會基於rank 索引對 rank 之間相等的塊進行求和。
  • “all-gather”階段,每個GPU上的聚合梯度分片可供所有GPU使用。

通過重新安排reduce scatter和all gather,每個DDP worker只需要儲存一個引數分片和優化器狀態。

2.2 比對

下圖顯示了標準DDP訓練(上半部分)和FSDP訓練(下半部分):

Full Sharded Data Parallel graph

  • 在標準的資料並行訓練方法中,每個GPU上都有一個模型副本,向前和向後傳遞的序列只在自己的資料分片上進行執行。在這些區域性計算之後,每個區域性過程的引數和優化器與其他GPU共享,以便計算全域性權重更新。

  • 在FSDP中:

    • Model shard :每個GPU上僅存在模型的分片
    • All-gather :每個GPU通過all-gather從其他GPU收集所有權重,以在本地計算前向傳播。就是論文思路Pp下劃線部分
    • Forward(local):在本地進行前向操作。前向計算和後向計算都是利用完整模型。
    • All-gather :然後在後向傳播之前再次執行此權重收集。就是論文思路Pp之中的下劃線部分
    • Backward(local):本地進行後向操作。前向計算和後向計算都是利用完整模型,此時每個GPU上也都是全部梯度
    • Reduce-scatter :在向後傳播之後,區域性梯度聚合並且通過 reduce-scatter 在各個GPU上分片,每個分片上的梯度是聚合之後本分割槽對應的那部分,就是論文思路Pg之中的下劃線部分。
    • Update Weight(local):每個GPU更新其區域性權重分片。

為了最大限度地提高記憶體效率,我們可以在每層向前傳播後丟棄全部權重,為後續層節省記憶體。這可以通過將FSDP包裝應用於網路中的每一層來實現(通過設定reshard_after_forward=True)。

下面是虛擬碼實現:

FSDP forward pass:
    for layer_i in layers:
        all-gather full weights for layer_i # 權重
        forward pass for layer_i
        discard full weights for layer_i # 權重

FSDP backward pass:
    for layer_i in layers:
        all-gather full weights for layer_i # 權重
        backward pass for layer_i
        discard full weights for layer_i # 權重
        reduce-scatter gradients for layer_i # 梯度

2.3 梳理

我們結合論文思路再來梳理一下 FSDP。

2.3.1 思路

論文思路如下:

  • Pp: Parameter Partitioning,每個程式只儲存與其分割槽對應的引數。正向和反向傳播需要其分割槽外的引數時,會通過broadcast操作從適當的資料並行程式接收這些引數。雖然乍一看,這可能會導致顯著的通訊開銷,但我們發現,這種方法只會將基線DP系統的總通訊量增加到1.5倍,同時實現與Nd成比例的記憶體減少。
  • Pos : Optimizer State Partitioning對於一個\(N_d\)並行度的DP來說,我們將優化器狀態分組到\(N_d\)個相等的分割槽中,這樣第i個資料並行程式只更新與第i個分割槽對應的優化器狀態。因此,每個資料並行過程只需要儲存和更新總優化器狀態 的$ \frac{1}{N_d}\(,然後只更新\) \frac{1}{N_d}$個引數。在每個訓練步驟結束時,我們會執行一個跨資料並行程式的all-gather操作,以獲得跨所有資料並行程式的完全更新的引數。
  • Pg: Gradient Partitioning由於每個資料並行程式只負責更新其相應的引數分割槽,因此,每個節點僅僅對自己負責的那部分引數的梯度進行規約。在歸併之後,每個節點只需要自己引數分割槽對應的梯度,對於其他的梯度不再需要,所以它們的記憶體可以被釋放。這將梯度的記憶體佔用從2ψ位元組縮減到 \(\frac{2ψ}{N_d}\)。實際上,這是一種 Reduce-Scatter操作,不同引數的梯度被減少到不同的程式之中。

總結一下:因為模型引數被分割槽,所以引數梯度(在框架實現中,梯度往往是引數的成員變數)自然就被分割槽了。分割槽的引數被設定到優化器之中,所以優化器只會優化本分割槽的引數,所以優化器狀態自然就是分割槽之後的。注意,在前向傳播和後向傳播時候,每個GPU都是用全部模型來計算,得到的梯度也是全部的梯度,只是儲存時候只儲存自己分割槽對應的部分

2.3.2 流程步驟

我們再來展示一下具體流程。假設資料並行度為 n,則有 n 個GPU,那麼每個GPU之上儲存總模型引數的 1/n,同時梯度,優化器狀態就自然被分割槽了,每個GPU之上還有資料並行。

  • 起始狀態:每個GPU之上是\(P_n, G_n, O_n\)。注意,因為本GPU上模型是\(P_n\),所以\(O_n\)自然就對應了\(P_n\),就自動分片了。
  • 正向計算時候,每個 \(GPU_n\) 都把自己負責的引數 \(P_n\) 廣播給其他所有的 GPU,前向計算之後,每個 \(GPU_n\) 都得到自己輸入訓練資料 \(data_n\) 的損失 \(loss_n\)
  • 反向計算時候,每個 \(GPU_n\) 也都把自己負責的引數 \(P_n\) 廣播給其他所有的 GPU,最後計算得到對應於資料 \(data_n\)的梯度 \(G_n\)
  • 將梯度 \(G_n\)聚合到對應的\(GPU_n\)上,這時候\(GPU_n\) 上的梯度就是 \(reduce(G_0, ..., G_n)\) 之中自己rank對應的部分。注意,梯度聚合過程則使用了 reduce-scatter,因為每個gpu只需要更新自己負責的部分\(P_n,G_n,O_n\),所以不需要進行all-gather了。

0x03 How to use FSDP

目前,FAIR 提供四種解決方案來使用FSDP,以適應不同的需求。

3.1 在語言模型中使用FSDP

對於語言模型,可以在通過以下新引數,在 fairseq framework 之中支援 FSDP:

  • –ddp-backend=fully_sharded: 通過FSDP啟用完全切分。
  • –cpu-offload: 將優化器狀態和FP32模型副本解除安裝到cpu(與–optimizer=cpu_adam結合使用)。
  • –no-reshard-after-forward: 提高大模型訓練速度 (1B+ params) ,類似於 ZeRO stage 2。
  • 其他常見選項 (–fp16, –update-freq, –checkpoint-activations, –offload-activations, etc.) 還是繼續正常工作。

具體請參閱fairseq教程

3.2 在計算機視覺模型之中使用FSDP

對於計算機視覺模型, VISSL 中可以支援FSDP,並在regnet架構上進行了測試。像BatchNorm和ReLU這樣的層已經被無縫地處理並已經測試過其收斂性。可以使用下面選項來啟用 FSDP。

  • config.MODEL.FSDP_CONFIG.AUTO_SETUP_FSDP=True
  • config.MODEL.SYNC_BN_CONFIG.SYNC_BN_TYPE=pytorch
  • config.MODEL.AMP_PARAMS.AMP_TYPE=pytorch

在如下連結可以繼續研究 this section

3.3 在PyTorch Lightning使用FSDP

為了更容易地與更通用的用例整合,PyTorch Lightning已經將FSDP作為beta功能。[此教程](https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#fully-sharded training) 包含一個關於如何將FSDP外掛與PyTorch Lightning一起使用的詳細示例。如下所示,新增plugins='fsdp'可以啟用它。

model = MyModel()
trainer = Trainer(gpus=4, plugins='fsdp', precision=16)
trainer.fit(model)

trainer.test()
trainer.predict()

3.4 直接從FairScale使用FSDP庫

FSDP的主要開發庫是FairScale.。您可以通過以下示例直接使用FairScale的FSDP,只需更換DDP。

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
...
# sharded_module = DDP(my_module)
sharded_module = FSDP(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
for sample, label in dataload.next_batch:
  out = sharded_module(x=sample, y=3, z=torch.Tensor([1]))
  loss = criterion(out, label)
  loss.backward()
  optim.step()

FairScale中的FSDP庫為大規模訓練的許多重要方面提供了選項。當你希望使用FSDP的全部功能,你可以自行研究如下方面。

  1. 模型封裝:為了最大限度地減少短期內的GPU記憶體需求,使用者需要以巢狀方式封裝模型。這增加了複雜性,但是在移植現有PyTorch模型程式碼時非常有用。
  2. 模型初始化:與DDP不同,FSDP不會在GPU工作程式之間自動同步模型權重。這意味著必須小心地進行模型初始化,以便所有GPU worker具有相同的初始權重。
  3. 優化器設定:由於分片和包裝,FSDP只支援某些型別的優化器和優化器設定。特別是,如果模組被FSDP包裝,並且其引數被展平為單個張量,則使用者不能對此類模組中的不同引數組使用不同的超引數。
  4. **混合精度 **:FSDP支援FP16主權重的高階混合精度訓練,以及在梯度上FP16型別的reduce和scatter。但是,模型的某些部分可能只有在使用全精度時才收斂,在這些情況下,需要額外的wrapping,以便有選擇地以全精度執行模型的某些部分。
  5. 狀態檢查點和推斷:當模型規模較大時,儲存和載入模型狀態可能會變得很困難。FSDP支援多種方法使該任務成為可能,但這些方法是有代價的。
  6. 最後,FSDP通常與啟用檢查點函式一起使用,如checkpoint_wrapper 。使用者可能需要仔細調整啟用檢查點策略,以便在有限GPU記憶體空間內容納一個大型模型。

0x04 記憶體管理

我們接下來看看FSDP如何管理記憶體。

FairScale提供了受ZeRO <https://arxiv.org/pdf/1910.02054.pdf> 啟發的演算法:當使用資料並行訓練時,您需要在計算/通訊效率方面權衡記憶體的使用。另一方面,在使用模型並行訓練時,需要為了記憶體而權衡計算/通訊。

模型訓練的記憶體使用通常分為兩類:

  • 模型狀態:優化器狀態、梯度、引數。

  • 剩餘狀態:啟用、臨時緩衝區、碎片記憶體。

為了減少模型狀態下的冗餘,ZeRO提出了三種不同的演算法。這些在FairScale中實現為優化器狀態分片(Optimizer State Sharding,即OSS)、分片資料並行(Sharded Data Parallel,即SDP)和最終完全分片資料並行(Fully Sharded Data Parallel,即FSDP)。讓我們深入瞭解每一個演算法的實際機制,並理解它們為什麼能夠節省記憶體。

4.1 Optimizer State Sharding (OSS)

FairScale已經實現了與優化器記憶體相關的記憶體優化 OSS。

像Adam這樣的優化器通常需要保持動量、方差。即使可以使用FP16精度的引數和梯度進行訓練,引數和梯度也需要儲存為FP32精度。當每個rank更新完整模型時,這意味著相當大一部分記憶體被優化器狀態的冗餘表示所佔用。

為了克服這種冗餘,優化器狀態分片需要將模型優化步驟劃分在不同的rank之間,以便每個rank只負責更新模型的對應分片。這反過來又確保優化器狀態在每個rank上小得多,並且它不包含跨rank的冗餘資訊。

4.1.1 訓練流程

訓練流程可以從DDP的執行流程做如下修改:

  1. wrapped optimizer根據引數大小(而不是使用順序)以貪心演算法方式來分割優化器狀態。這是為了確保每個rank具有幾乎相同的優化器記憶體佔用。

  2. 訓練過程類似於PyTorch的分散式資料並行(DDP)的過程。在每個rank上完成前向傳播,然後是向後傳播。在後向傳播過程中,使用allreduce同步梯度。

  3. 每個rank只更新它負責的優化器分配狀態引數,然後丟棄其餘的。

  4. 更新後,將執行broadcast或allgather操作,以確保所有rank都收到最新更新的引數值。

當您使用具有附加狀態的優化器(如Adam)時,OSS非常有用。如果您使用的是SGD或任何記憶體佔用有限的優化器,那麼在使用多個節點時,由於步驟4中的額外通訊,您可能會看到速度減慢。在第2步的allreduce過程中,也有一些用於儲存梯度的浪費記憶體,這些記憶體隨後被丟棄。

4.1.2 最佳實踐

  • OSS公開了一個broadcast_fp16 flag,您可能應該在多節點作業中使用它。在單節點實驗中通常不需要這樣做。

  • 如果您的模型在大小方面極不平衡(例如,存在一個巨大的張量),那麼這種方法將不會有很大幫助,而張量切分選項,如'fairscale.nn.FullyShardedDataParallel'將更可取。

  • 3.OSS應該是DDP環境中的一個臨時解決方案,其與大多數DDP功能保持相容。

4.1.3 效能

  • 在單個節點上,OSS應該總是比vanilla PyTorch快,記憶體節省會因使用的優化器而異

  • 當使用多個節點時,OSS也可以比vanilla PyTorch快或慢,具體取決於所使用的優化器和可選標誌(如上文提到的broadcast_fp16、梯度壓縮、梯度累積)

  • 如果您的實驗可以使用更大的batch size,則採取更大的batch size並減少所涉及的rank數通常是有益的,或者使用梯度累積,因為這樣可以降低通訊成本。

4.2 Optimizer + Gradient State Sharding

雖然OSS解決了優化器中的冗餘問題,但依然存在梯度聚合計算的重複以及存在用於梯度的額外記憶體。為了克服冗餘梯度記憶體,我們可以使用梯度分片或ZeRO-2。這已由FairScale中的分片資料並行(SDP)API實現。

為了啟用梯度分片,每個 rank 都被分配一組引數,它們負責管理優化器狀態以及梯度聚合。通過將一個模型分片分配給一個給定的rank,我們確保梯度被規約到特定的rank,而這些rank又負責相應的更新。因此這減少了通訊和記憶體使用。

4.2.1 訓練過程

訓練過程如下:

  1. 與之前一樣,包裝的優化器在不同的列組中分割引數。

  2. 該模型現在使用分片資料並行(SDP)包裝器進行包裝,該包裝器允許我們在訓練過程中新增適當的hook並維護狀態。

  3. SDP關注於可訓練的引數,併為每個引數新增了一個反向hook。

  4. 在反向傳播過程中,梯度將規約到指定rank,rank是在 1 中作為切分過程的一部分指定的。使用reduce op代替allreduce op,從而減少通訊開銷。

  5. 每個rank更新其負責的引數。

  6. 更新後,將執行廣播或allgather,以確保所有rank都收到最新更新的引數值。

OSS和SDPAPI都允許您減少用於梯度和優化器狀態的記憶體,但是如果網路緩慢,則可能存在額外的通訊成本。當遇到記憶體不足(OOM)問題時,可以把OSS和SDP作為第一步嘗試。

4.2.2 最佳實踐

  • 如果使用多個節點,請通過指定reduce_buffer_size 引數確保SDP正在使用reduce buffers。改變它們的大小可能是一個優化目標,最佳配置可能取決於互連狀況。

  • 如果在單個節點上,通常最好不要使用'reduce_buffer_size',因為它會帶來延遲成本,但不會增加記憶體。將此值設定為0表示不使用此功能。

  • 如果您的實驗可以使用更大的batch size,則採取更大的batch size並減少所涉及的rank數通常是有益的,或者使用梯度累積,因為這樣可以降低通訊成本。

4.3 Optimizer + Gradient + Horizontal Model Sharding

為了進一步優化訓練並實現更大的記憶體節省,我們需要啟用引數切分。

引數切分類似於梯度和優化器狀態,即,每個資料並行rank負責模型引數的一個分片。FairScale通過完全分片資料並行(FSDP)API實現引數分片,該API深受 ZeRO-3 <https://arxiv.org/pdf/1910.02054.pdf>的啟發。

引數分片有兩個如下關鍵點:

  • Allreduce操作可以分為reduce和allgather,類似於以前的分片技術(優化器狀態和梯度)。

  • 可以使用FSDP API包裝各個層,該API允許我們在給定例項中將單個層所需的所有引數引入給定GPU,計算前向傳遞,然後丟棄不屬於該rank的引數。

使用FSDP很簡單,只需要在程式碼中簡單地替換原來的DDP即可。注意:FSDP目前要求模型是一個nn.Sequential模型。

from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

from fairscale.experimental.nn.offload import OffloadModel

num_inputs = 8
num_outputs = 8
num_hidden =  4
num_layers =  2
batch_size =  8

transform = ToTensor()
dataloader = DataLoader(
    FakeData(
        image_size=(1, num_inputs, num_inputs),
        num_classes=num_outputs,
        transform=transform,
    ),
    batch_size=batch_size,
)

model = torch.nn.Sequential(
    torch.nn.Linear(num_inputs * num_inputs, num_hidden),
    *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
    torch.nn.Linear(num_hidden, num_outputs),
)

4.3.1 訓練過程

具體訓練過程如下:

  • 在開始計算特定層之前,allgather模型每個層的正向傳播所需的引數。

  • 計算向前計算。

  • 在特定層開始反向傳遞之前,allgather模型每個層反向傳播所需的引數。

  • 計算向後傳播。

  • 規約梯度,以便在負責相應引數的rank上累積聚合梯度。

  • 讓每個rank使用聚合梯度更新已分配給它的引數。

有了FSDP,在使用API進行檢查點設定和儲存優化器狀態時,需要做一些小的更改。鑑於優化器狀態和引數的分片性質,任何旨在儲存模型狀態以供訓練或推理的API都需要考慮儲存所有worker的權重。FSDP實現所需的管道(required plumbing)以儲存所有worker的權重、儲存單個worker的權重以及儲存所有worker的優化器狀態。

FSDP還支援混合精度訓練,其中計算和通訊均以FP16精度進行。如果要減少在FP32中執行的操作(這是DDP的預設行為),則必須設定 fp32_reduce_scatter=True

為了進一步節省記憶體,FSDP支援將當前未使用的引數和梯度解除安裝到CPU上。這可以通過將“move_params_to_cpu”和“move_grads_to_cpu”設定為True來啟用。

4.3.2 最佳實踐

  • 對於FSDP,最好使用 model.zero_grad(set_to_none=True) ,因為它在單步執行後節省了大量記憶體。

  • torch.cuda.amp.autocast與FSDP完全相容。您需要將'mixed_precision'arg設定為True。

  • 如果與啟用檢查點相結合,則最好使用 FSDP(checkpoint_wrapper(module))而不是checkpoint_wrapper(FSDP(module)).。後者將導致更多的通訊,速度也會變慢。

  • FSDP與使用pointwise優化器的DDP相容,例如Adam、AdamW、ADADDelta、Adamax、SGD等。當使用non-pointwise優化器(例如Adagrad、Adafactor、LAMB等)時,sharding將導致略有不同的結果。

4.3.3 效能

  • 為了獲得最佳記憶體效率,請使用“auto_wrap”將網路中的每一層用FSDP進行封裝,並將 reshard_after_forward 設定為True。這樣速度會慢,但是視訊記憶體開銷最小。

  • 為了獲得最佳訓練速度,請將 reshard_after_forward 設定為False(不需要包裝每一層,但如果設定,則會進一步提高速度)。

支援,FSDP基本原理和如何使用我們已經介紹完畢,下一篇我們介紹其程式碼細節,看看究竟如何做到最大程度減少記憶體使用。

0xFF 參考

Fully Sharded Data Parallel: faster AI training with fewer GPUs

ZeRO & DeepSpeed:可以讓訓練模型擁有超過1000億個引數的優化(微軟)

Fully Sharded Data Parallel: faster AI training with fewer GPUs

https://github.com/microsoft/DeepSpeed

ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

Automatic Cross-Replica Sharding of Weight Update in Data-Parallel Training

相關文章