使用 PyTorch FSDP 微調 Llama 2 70B

HuggingFace發表於2023-12-12

引言

透過本文,你將瞭解如何使用 PyTorch FSDP 及相關最佳實踐微調 Llama 2 70B。在此過程中,我們主要會用到 Hugging Face Transformers、Accelerate 和 TRL 庫。我們還將展示如何在 SLURM 中使用 Accelerate。

完全分片資料並行 (Fully Sharded Data Parallelism,FSDP) 是一種訓練正規化,在該正規化中最佳化器狀態、梯度和模型引數都會被跨裝置分片。前向傳播時,每個 FSDP 單元執行 all gather 以獲取完整的權重,然後用它們進行計算並在計算後丟棄掉其他裝置的分片。隨後是反向傳播,然後就是損失計算。反向傳播時,每個 FSDP 單元執行 all gather 操作以獲取完整的權重,並執行計算以獲得本地 batch 的梯度。這些梯度透過 reduce scatter 在裝置上進行均值計算並分片,這樣每個裝置都可以更新其對應分片的引數。有關 PyTorch FSDP 的更多資訊,請參閱此博文: 使用 PyTorch 完全分片資料並行技術加速大模型訓練

FSDP 工作流

(圖源: 連結)

使用的硬體

節點數: 2,至少 1 個節點
每節點 GPU 數: 8
GPU 型別: A100
GPU 視訊記憶體: 80GB
節點內互聯: NVLink
每節點記憶體: 1TB
每節點 CPU 核數: 96
節點間互聯: AWS 的 Elastic Fabric Adapter (EFA)

微調 LLaMa 2 70B 面臨的挑戰

在嘗試使用 FSDP 微調 LLaMa 2 70B 時,我們主要遇到了三個挑戰:

  1. FSDP 會先載入整個預訓練模型,然後再對模型進行分片。這樣就意味著節點內的每個程式 (即 rank) 都會載入整個 Llama-70B 模型,因此需要 7048 GB ~ 2TB 的 CPU 記憶體,這個算式中 4 是每個引數所需位元組數,8 是每個節點的 GPU 數。這會導致 CPU 記憶體不足,進而導致程式終止。
  2. 使用 FULL_STATE_DICT 來儲存完整中間檢查點並將其解除安裝至 rank 0 的 CPU 記憶體中需要花費大量時間,且由於在此期間通訊庫需要無限期掛起等待儲存完成,因此經常會導致 NCCL 超時錯誤。然而,完全關掉這個選項也不好,因為在訓練結束時我們需要儲存完整的模型狀態字典,而不是 FSDP 式分片的狀態字典。
  3. 我們需要提高速度並減少視訊記憶體使用,以加快訓練並節約計算成本。

下文,我們主要討論如何一一解決上述挑戰,最終微調出一個 70B 的模型!

先列出重現結果所需的所有資源:

  1. 程式碼庫: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/training,程式碼中包含了使能 flash 注意力 V2 的熱補丁
  2. FSDP 配置檔案: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml
  3. SLURM 啟動指令碼 - launch.slurm : https://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25
  4. 模型: meta-llama/Llama-2-70b-chat-hf
  5. 資料集: smangrul/code-chat-assistant-v1 (混合了 LIMA 和 GUANACO 資料集,且已轉換為訓練所需的格式)

準備工作

首先按照 此步驟 安裝 Flash Attention V2。然後,安裝最新的 PyTorch nightly (CUDA ≥11.8)。接著,根據 此檔案 安裝其餘依賴軟體。在本文中,我們是從主分支安裝 🤗 Accelerate 和 🤗 Transformers 的。

微調

應對挑戰 1

PR 25107 和 PR 1777 解決了第一個挑戰,且無需使用者側更改任何程式碼。主要做的事情如下:

  1. 在所有 rank 上建立無權重的空模型 (使用 meta 裝置)
  2. 僅在 rank 0 上將狀態字典載入至模型
  3. 其他 rank 僅對 meta 裝置上的引數執行 torch.empty(*param.size(), dtype=dtype)
  4. 因此,只有 rank 0 上載入了完整的模型及權重,而所有其他 rank 上的權重是空的
  5. 設定 sync_module_states=True ,以便 FSDP 例項在訓練開始之前將權重廣播到各 rank

下面是在 2 個 GPU 上載入 7B 模型的輸出日誌片段,它測量了各個階段記憶體的消耗及其載入的模型引數量。我們可以觀察到,在載入預訓練模型時,rank 0 和 rank 1 的 CPU 峰值記憶體分別為 32744 MB1506 MB 。因此可知,僅有 rank 0 載入了預訓練模型,這就實現了 CPU 記憶體的有效利用。你可在 此處 找到完整日誌。

accelerator.process_index=0 GPU Memory before entering the loading : 0
accelerator.process_index=0 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=0 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=0 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=0 CPU Memory before entering the loading : 926
accelerator.process_index=0 CPU Memory consumed at the end of the loading (end-begin): 26415
accelerator.process_index=0 CPU Peak Memory consumed during the loading (max-begin): 31818
accelerator.process_index=0 CPU Total Peak Memory consumed during the loading (max): 32744

accelerator.process_index=1 GPU Memory before entering the loading : 0
accelerator.process_index=1 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=1 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=1 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=1 CPU Memory before entering the loading : 933
accelerator.process_index=1 CPU Memory consumed at the end of the loading (end-begin): 10
accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-begin): 573
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506

應對挑戰 2

該挑戰可以透過在配置 FSDP 時將狀態字典型別設為 SHARDED_STATE_DICT 來解決。設為 SHARDED_STATE_DICT 後,每個 rank 各自儲存各自 GPU 所需要的分片,這使得使用者可以快速儲存中間檢查點並快速從其恢復訓練。而當使用 FULL_STATE_DICT 時,第一個程式 (rank 0) 會用 CPU 收集整個模型,然後將其儲存為標準格式。

我們可以用以下命令建立相應的 accelerte 配置檔案:

accelerate config --config_file "fsdp_config.yaml"

fsdp 配置

你可以從此處獲取生成的配置檔案: fsdp_config.yaml。在該配置檔案中,分片策略是 FULL_SHARD 。我們使用 TRANSFORMER_BASED_WRAP 作為自動模型包裝策略,它使用 _no_split_module 來搜尋 transformer 塊名並自動進行巢狀 FSDP 包裝。我們使用 SHAARDED_STATE_DICT 把中間檢查點和最佳化器狀態儲存為 PyTorch 官方推薦的格式。同時,如上一節 應對挑戰 1 中所述,我們還需要確保訓練開始時用 rank 0 來廣播引數。從配置檔案中你還可以看到我們用的是 bf16 混合精度訓練。

那麼,在儲存最終檢查點時,如果將其儲存成單個檔案呢?我們使用的是以下程式碼段:

if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir) # 或者 , 如果整個模型小於 50 GB (即 LFS 單檔案的最大尺寸),你還可以使用 trainer.push_to_hub() 把模型推到 hub 上去。

應對挑戰 3

為了加快訓練速度並減少視訊記憶體佔用,我們可以使用 flash 注意力並開啟梯度檢查點最佳化,從而在微調的同時節省計算成本。當前,我們用了一個熱補丁來實現 flash 注意力,具體程式碼可見 這兒

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 一文基於對底層硬體 (即 GPU) 的記憶體層次結構的深刻理解而引入了一種更快、更節省記憶體的無損注意力加速演算法。底層硬體在設計記憶體層次結構時,遵循的實踐原則是: 頻寬/速度越高的記憶體,其容量越小,因為它更貴。

根據博文 根據第一性原理讓深度學習效能起飛,我們可以發現,當前硬體上的注意力模組是 記憶體頻寬受限 的。原因是注意力機制 主要由逐元素操作 組成,如下左圖所示。我們可以觀察到,掩碼、softmax 和 dropout 操作佔用了大部分時間,而非需要大量 FLOP 的矩陣乘法。

注意力機制的效能瓶頸

(圖源: 連結)

這正是 flash 注意力解決的問題,其想法是 去除冗餘的 HBM 讀/寫操作。該演算法透過將所有內容保留在 SRAM 中,待執行完所有中間步驟後再將最終結果寫回到 HBM,即 運算元融合 來實現這一目的。下圖簡要描述了運算元融合是如何克服記憶體瓶頸的。

運算元融合

(圖源: 連結)

在前向和反向傳播過程中我們還使用了 平鋪 (Tiling) 最佳化技巧,將 NxN 大小的 softmax 分數計算切成塊,以克服 SRAM 記憶體大小的限制。在使用平鋪技巧時,我們會使用線上 softmax 演算法。同時,我們還在反向傳播中使用了 重計算 技巧,以大大降低在前向傳播過程中儲存整個 NxN softmax 分數矩陣所帶來的記憶體消耗。

如欲深入理解 flash 注意力,請參考博文 ELI5: FlashAttention根據第一性原理讓深度學習效能起飛 以及原始論文 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

綜合運用所有手段

你可參考 此指令碼,以在 SLURM 中用 Accelerate 啟動器執行訓練。下面還給出了一個等效命令,展示瞭如何使用 Accelerate 啟動器來執行訓練。請注意,該命令會覆蓋 fsdp_config.yaml 中的 main_process_ipmain_process_portmachine_ranknum_processes 以及 num_machines 配置。另一個需要重點注意的是,這裡的儲存是所有節點共享的。

accelerate launch \
    --config_file configs/fsdp_config.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$MACHINE_RANK \
    --num_processes 16 \
    --num_machines 2 \
    train.py \
    --model_name "meta-llama/Llama-2-70b-chat-hf" \
    --dataset_name "smangrul/code-chat-assistant-v1" \
    --max_seq_len 2048 \
    --max_steps 500 \
    --logging_steps 25 \
    --eval_steps 100 \
    --save_steps 250 \
    --bf16 True \
    --packing True \
    --output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --dataset_text_field "content" \
    --use_gradient_checkpointing True \
    --learning_rate 5e-5 \
    --lr_scheduler_type "cosine" \
    --weight_decay 0.01 \
    --warmup_ratio 0.03 \
    --use_flash_attn True

整個微調過程需要約 13.5 小時,下圖給出了訓練損失曲線。

訓練損失曲線

下例給出了使用上述模型完成的一段對話:

System Prompt: You are a helpful, respectful and honest assistant. Always answer as helpfully \
as possible, while being safe. Your answers should not include any harmful, \
unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that \
your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why \
instead of answering something not correct. If you don’t know the answer to a \
question, please don’t share false information.

- Human: What is Deep Learning? Explain like a Pirate.

+ Assistant: Arrr Grumete! Are ye lookin' fer a proper explanation o' Deep Learnin'?
+ Alright then, listen up me hearty! Deep Learnin' be a type o' machine learnin' that uses artificial neural networks to analyze data.
+ Aye, ye heard that right, mateys! Artificial neural networks be models o' the human brain, but instead o' using neurons, they use interconnected nodes or "artificial neurons" to process information.
+ These nodes work together to form layers o' processing units that can learn from large amounts o' data.
+ The more data these systems receive, the better they get at recognizing patterns and makin' decisions on their own.
+ So hoist the sails and set course for adventure with deep learnin', me hearties! It's a treasure trove o' knowledge just waitin' to be discovered.
+ Fair winds and following seas!

- Human: Now explain it like a chef.

+ Assistant: Certainly! Here's an explanation of deep learning from a culinary perspective:

+ Think of a recipe as a sequence of steps used to transform raw ingredients into a delicious dish.
+ Similarly, in deep learning, there are multiple layers of "ingredients" (or features) that are combined and transformed through various operations to produce a final output or prediction.
+ Each layer builds upon the previous one, refining and enhancing the representation of the input data until it reaches the desired outcome.
+ Just as different spices enhance the flavors of food, each layer adds depth and complexity to the model's understanding of the input, allowing it to make more accurate predictions.
+ And much like how chefs must carefully balance the amount of salt, sugar, and other seasonings in a recipe, the training process in deep learning involves adjusting the weights and biases of the network to achieve the perfect blend of accuracy and efficiency.
+ With enough practice and experimentation, anyone can become a master chef... I mean, a master of deep learning! Bon appétit!

整個對話使用的格式如下:

<|system|> system message <|endoftext|> <|prompter|> Q1 <|endoftext|> <|assistant|> A1 <|endoftext|> ...

總結

我們在多節點多 GPU 上使用 PyTorch FSDP 成功微調了一個 70B Llama 模型,並在此過程中解決了各種挑戰。我們看到了當前在 🤗 Transformers 和 🤗 Accelerates 中應如何初始化大模型從而有效克服 CPU 記憶體不足的問題。我們還給出瞭如何高效地儲存/載入中間檢查點,同時又能以易於使用的方式儲存最終模型的最佳實踐。為了加速訓練並減少 GPU 視訊記憶體使用,我們還強調了 flash 注意力和梯度檢查點機制的重要性。最後,我們向大家展示了在 🤗 Accelerate 上僅需要簡單的配置就可以在多節點多 GPU 上微調大模型。


英文原文: https://hf.co/blog/ram-efficient-pytorch-fsdp

原文作者: Sourab Mangrulkar,Sylvain Gugger,Lewis Tunstall,Philipp Schmid

譯者: Matrix Yao (姚偉峰),英特爾深度學習工程師,工作方向為 transformer-family 模型在各模態資料上的應用及大規模模型的訓練推理。

FSDP MFU (Model FLOPS Utilization) 相關討論: https://github.com/huggingface/blog/issues/1649

相關文章