一行程式碼Post-Train任意長序列!360智腦開源360-LLaMA-Factory

机器之心發表於2025-01-10
圖片
AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

專案核心開發者 Haosheng Zou 本科畢業於清華大學電子系,博士畢業於清華大學計算機系朱軍教授組,目前在 360 智腦從事長文字和強化學習等後訓練工作。開發者 Xiaowei Lv 目前在人民大學資訊學院研二在讀。Fenrui Xiao、Junchen Liu、Qi An 和 Xiaodong Sun 等在開發測試中亦有貢獻。

大模型長序列的處理能力已越來越重要,像複雜長文字任務、多幀影片理解任務、以及 OpenAI 近期釋出的 o1、o3 系列模型的高計算量模式,需要處理的輸入 + 輸出總 token 數從幾萬量級上升到了幾百萬量級。面對模型日益增長的長序列需求,在預訓練(Pre-Training)和後訓練(Post-Training)階段,所用的平臺框架都需要支援更長序列資料的訓練。不同於預訓練階段基於 Megatron-LM 定製開發的常見選擇,後訓練階段因後訓練演算法的多樣性(比如僅 DPO 就有幾十個變種)和訓練需求的靈活性,至今沒有一個框架同時在並行策略、後訓練演算法、GPU 視訊記憶體最佳化和簡單易用這 4 個方面上全部做到相容幷包。

在所有開源的後訓練框架中,LLaMA-Factory 是使用者最多的框架之一(GitHub star 數已 37k 多),保持長期迭代更新,支援豐富的模型和後訓練演算法,有各種 GPU 視訊記憶體最佳化技巧和簡單易用的方式。然而,LLaMA-Factory 在長序列後訓練上支援仍有所欠缺,尚不支援長序列的關鍵技術 —— 序列並行。
圖片
專案主頁:https://github.com/Qihoo360/360-LLaMA-Factory

最近,360 智腦基於 LLaMA-Factory 開源了 360-LLaMA-Factory,加入了序列並行功能,一行程式碼即可支援任意長序列的後訓練(Post-Training)—— 僅需額外指定序列並行一個引數:
sequence_parallel_size: 16

按需增加序列並行的 GPU 卡數,即可在任意長度的序列上 SFT 或 DPO。

360-LLaMA-Factory 的實現經過了嚴格的正確性驗證,已在主倉 Pull Request 中稽核過。正式合併進 LLaMA-Factory 主倉之前,可先使用 360-LLaMA-Factory。

1、專案背景與專案簡介

360 智腦早在 2023 年就開始了長文字大模型的研發,到目前為止已經成功應用於開源並更新了兩個版本的 360Zhinao-7B-Chat-360k 模型,以及近日釋出的長思維鏈推理模型 360gpt2-o1。在 360-LLaMA-Factory 中,我們將 360 智腦內部長序列後訓練能力系統性地整合進了 LLaMA-Factory 中,使用者僅需額外新增一行程式碼,即可進行理論上任意長度的長序列後訓練(增加序列並行的 GPU 卡數即可):
sequence_parallel_size: 16
在原先使用 LLaMA-Factory 的基礎上,只需額外增加一個引數

透過這種方式,360-LLaMA-Factory 將 LLaMA-Factory 的序列並行也做到了簡單易用和相容幷包,和 LLaMA-Factory 的其他功能完全相容。

粗粒度地測試 8 卡 80G 的全引數後訓練(不考慮除了 zero3-offload 和 gradient checkpointing 外的任何最佳化技巧),360-LLaMA-Factory 至少可以訓到 SFT 210k (7B) / 128k (72B) 和 DPO 84k (7B) / 46k (72B)。若加上注掉 logits = logits.float () 和 DPO 預計算等技巧,2 卡序列並行即可解決諸多常見的訓練需求。360-LLaMA-Factory 讓序列並行也真正成為了簡單好用、效果也好的後訓練工具。

作為開源社群的一份子,360-LLaMA-Factory 離不開 LLaMA-Factory、ring-flash-attention 和 EasyContext 等開源專案的開創性工作,我們的底層開發部分依賴了這些工作,但也有我們自己在具體實現方式上的不同和見解。我們相信我們的程式碼實現已做到儘可能好的模組化和儘可能少的原始程式碼修改,且嚴格檢查過正確性,因此也已向 LLaMA-Factory 主倉提交了 Pull Request,初步稽核透過。我們樂於同開源社群共建完善這項工作。

2、長序列及其後訓練

2.1 長序列大模型的訓練:預訓練 vs 後訓練

隨著大模型訓練資料長度的增長,預訓練和後訓練平臺框架都需要支援長序列資料訓練。
  • 預訓練階段,英偉達的 Megatron-LM 憑藉豐富高效的並行策略與出色的 GPU 視訊記憶體最佳化,成為主流框架,基於它的定製開發往往是最通用的解法, Megatron-LM 本身已實現了序列並行(Megatron-LM 稱之為 context parallelism,其他工作一般稱為 sequence parallelism)。

  • 後訓練階段情況相對複雜。後訓練演算法多樣,如 DPO 就有諸多變種,且訓練需求靈活多變,不同場景對演算法、資源、並行性等要求各異。因此,至今沒有一個框架能在並行策略、後訓練演算法、GPU 視訊記憶體最佳化和易用性這四個關鍵方面做到近乎完美的相容。雖有框架在部分方面表現尚可,但總體仍存在短板,這也限制了模型在長序列資料後訓練上的進一步發展。

2.2 長序列的通解 —— 序列並行及其難點

長序列後訓練面臨的關鍵瓶頸是:序列長度增加時,啟用視訊記憶體會大幅上升。雖然有 unsloth、liger kernel、LoRA 等多種降低視訊記憶體佔用的技巧,但均未從根本上解決序列長度增加的本質問題,其效果存在明確上限。

序列並行(sequence parallelism)被認為是解決長序列訓練問題的通解,它透過把一條長序列切分到不同的顯示卡上進行計算,從而避免了每張顯示卡處理過長的序列,從根本上解決了 “每張顯示卡處理的序列長度增加” 的問題。然而,序列並行的實現難度較大,需要在切分後的序列之間進行通訊計算 attention,需要侵入修改原始的 attention 函式。在開源的 Megatron-LM 中,序列並行也是所有並行策略中最後才新增的,LLaMA-Factory 之前還沒有支援序列並行。

2.3 序列並行後訓練的相關工作

我們調研了其他一些支援序列並行的開源框架,有些實現上有錯或小 bug、導致支援的後訓練演算法不全;有些更新維護不及時、訓練較新的模型不方便、顯示進度條等易用性不足。有的與 LLaMA-Factory 相比繼承依賴更少,支援功能較少但更乾淨、更適合定製開發,有不同的使用場景。此外,各家的序列並行具體實現也不盡相同。詳見下面的表 1 和 GitHub README,有未調研到的也請包涵並聯系 360-LLaMA-Factory。
圖片
表 1:一些支援序列並行的後訓練框架對比

3、360-LLaMA-Factory 框架解析

360-LLaMA-Factory 系統性地為 LLaMA-Factory 增加了序列並行的支援。以下將簡要介紹 360-LLaMA-Factory 框架中的模組化修改和執行流程。

3.1 360-LLaMA-Factory 的框架和模組化封裝

360-LLaMA-Factory 將序列並行的程式碼做到了儘可能好的模組化和儘可能少的原始程式碼修改。

我們認為序列並行本質上應認為是對模型的修改,因此在 model_args 中增加了引數並抽象為 apply_sequence_parallel 修改模型的函式。
# src/llamafactory/model/loader.py
sequence_parallel_group = apply_sequence_parallel(model_args) # 序列並行monkey patch,改動attention計算
...
model.sequence_parallel_group = sequence_parallel_group # 維護模型的序列並行組,不開則為None

相應地,資料處理部分也要相應地修改,我們將 zigzag ring attention 所需的資料處理抽象成了一個 decorator,裝飾原來的資料處理函式。背後,這會將先 shuffle、packing、預處理好的資料進一步做好序列並行的準備:先將每行 pad 或截斷到指定的訓練長度,再按 zigzag 切分並按順序寫入資料集,最後在訓練時用 SequentialSampler 讀取訓練資料。
# src/llamafactory/data/loader.py
@sequence_parallel_decorator
def get_dataset(...)

loss 計算則需要在 Trainer 中做序列並行組內的 reduce 彙總和計算。
# src/llamafactory/train/sft/trainer.py
dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(label_num, op=dist.ReduceOp.SUM, group=sp_group)
loss /= label_num
# src/llamafactory/train/dpo/trainer.py
dist.all_reduce(policy_chosen_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(policy_rejected_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(reference_chosen_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(reference_rejected_logps, op=dist.ReduceOp.SUM, group=sp_group)

3.2 360-LLaMA-Factory 的 SFT 和 DPOTrainer

除了統一的模組化抽象,序列並行也需要對 360-LLaMA-Factory 的 Trainer 稍做定製化的修改,以適配各底層庫。針對最普遍的後訓練需求 SFT 和 DPO(及其變種),我們對 360-LLaMA-Factory 中的 SFT 和 DPOTrainer 做了儘可能少且清晰的修改。

其中,dummy_forward 是因為我們發現基於目前的底層序列並行實現,在第一次 forward 時 DPO loss 不等於 log (sigmoid (0)),但學習率設為 0 時之後的 DPO loss 全都等於。因此,訓練最開始時先做且僅做一次假前傳,不對正式訓練迴圈造成任何影響。

從 SFT 和 DPO 的序列並行對比圖中,可以清晰地看出 360-LLaMA-Factory 序列並行帶來的改動。
圖片
圖 3:360-LLaMA-Factory SFT 序列並行對比
圖片
圖 4:360-LLaMA-Factory DPO 序列並行對比

4、360-LLaMA-Factory 效果驗證

內部 360-LLaMA-Factory 的早期版本已訓練了開源的 360Zhinao2-7B-Chat-360k。

為驗證本次開源的 360-LLaMA-Factory 的正確性,我們用總量為 30 條的小資料集,驗證了序列並行開與不開的對比情況下,訓練曲線的差別,以此來確保 360-LLaMA-Factory 所有實現的正確性。從下圖可見,序列並行對訓練曲線的影響幾乎可以忽略不計,DPO 稍有一定數值誤差,但我們也仔細檢查了該誤差與 DeepSpeed Ulysses 的誤差範圍一致,很可能部分是平行計算本身的隨機性導致的,亦可參考 ring-flash-attention 的詳細說明。
圖片
圖 5:360-LLaMA-Factory SFT 和 DPO 序列並行開關對比

為便於對比效果,我們基於第三方全尺寸開源模型粗粒度壓測了最大訓練長度,如下表 2、表 3 所示,可見 8 卡 80G 的序列並行上限已可滿足幾十至幾百 k 超長序列的需求:
圖片
表 2:第三方開源模型多尺寸 SFT 長度壓測
圖片
表 3:第三方開源模型多尺寸 DPO 長度壓測

5、總結

360 智腦開源了 360-LLaMA-Factory,支援了序列並行,僅需額外 1 個引數控制。基於 LLaMA-Factory 和 ring-flash-attention 開發,360-LLaMA-Factory 的實現模組化、效果正確且在長序列上有效。

歡迎開發者們使用和開發。在本倉庫(https://github.com/Qihoo360/360-LLaMA-Factory)下提交序列並行相關的 issue 或 PR 即可。

也歡迎研究者們,尤其是依賴長序列大模型的研究者們,在研究中使用我們的程式碼,可以這樣引用我們的工作:
@software{360-llama-factory,
author = {Haosheng Zou, Xiaowei Lv, Shousheng Jia and Xiangzheng Zhang},
title = {360-LLaMA-Factory},
url = {https://github.com/Qihoo360/360-LLaMA-Factory},
year = {2024}}

建議同時引用 LLaMA-Factory 和 ring-flash-attention 相關工作。

相關文章