極長序列、極快速度:面向新一代高效大語言模型的LASP序列並行

机器之心發表於2024-04-16

圖片

AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com。

從國際頂流 GPT-4 128K、Claude 200K 到國內「當紅炸子雞」支援 200 萬字上下文的 Kimi Chat,大語言模型(LLM)在長上下文技術上不約而同地捲起來了。當全世界最聰明的頭腦都在卷一件事的時候,這件事的重要性和難度就自然不言自明。

極長的上下文可以極大擴充大模型的生產力價值。隨著 AI 的普及,使用者已經不再滿足於調戲大模型幾個腦筋急轉彎,使用者開始渴望利用大模型來真正提高生產力。畢竟從前花一週憋出來的 PPT,現在只需要餵給大模型一串提示詞和幾份參考文件就分分鐘生成出來,打工人誰能不愛呢?

新型高效序列建模方法比如:Lightning Attention (TransNormerLLM), State Space Modeling (Mamba), Linear RNN (RWKV, HGRN, Griffin) 等最近成為炙手可熱的研究方向。研究人員渴望透過改造已經 7 歲高齡的 Transformer 架構,獲得效能與之旗鼓相當,但複雜度僅為線性的新型架構。這類方法專注於模型架構設計,並提供了基於 CUDA 或 Triton 的硬體友好實現,使其能夠像 FlashAttention 一樣在單卡 GPU 內部高效計算。

與此同時,另一個長序列訓練的殺手鐧:序列並行獲得了越來越多的關注。透過把長序列在序列維度切分為多個等分短序列,再將短序列分散至不同 GPU 卡並行訓練,再輔以卡間通訊便達到了序列並行訓練的效果。從最早出現的 Colossal-AI 序列並行、到 Megatron 序列並行、再到 DeepSpeed Ulysses、以及近期的 Ring Attention,研究人員不斷設計更加優雅高效的通訊機制以提升序列並行的訓練效率。當然這些已知方法全部是為傳統注意力機制設計的,本文中我們稱之為 Softmax Attention。這些方法也已經有各路大神做了精彩分析,本文不過多探討。

那麼問題來了。如何讓新型高效序列建模方法實現序列並行,從而跨越單卡 GPU 視訊記憶體限制實現真正意義的無限序列長度(當然你得有無限 GPU)的高效大語言模型訓練,成為了一個開放的問題。已經成熟的序列並行方法如 DeepSpeed Ulysses, Megatron-SP 當然可以應用線上性序列建模方法上,但以 Softmax Attention 為設計藍本的它們註定天生不是最優解。
圖片
  • 論文標題:Linear Attention Sequence Parallelism
  • 論文地址:https://arxiv.org/abs/2404.02882
  • LASP程式碼地址:https://github.com/OpenNLPLab/LASP

本文即將介紹的 LASP 便應運而生。來自上海人工智慧實驗室的研究人員提出了 Linear Attention Sequence Parallelism (LASP) 方法以充分利用 Linear Attention 的線性右乘特性實現高效的序列平行計算。在 128 卡 A100 80G GPU、TransNormerLLM 1B 模型、FSDP backend 的配置下,LASP 可以最高將序列長度擴充套件至 4096K,即 4M。與成熟的序列並行方法相比,LASP 可訓練的最長序列長度是 Megatron-SP 的 8 倍、DeepSpeed Ulysses 的 4 倍,速度則分別快了 136% 和 38%。

值得注意的是,雖然方法的名字包含 Linear Attention,LASP 並不侷限於 Linear Attention 方法,而是可以廣泛應用於包括 Lightning Attention (TransNormerLLM), State Space Modeling (Mamba), Linear RNN (RWKV, HGRN, Griffin) 等在內的線性序列建模方法。

圖片

LASP 方法介紹

為了充分理解 LASP 的思路,讓我們先回顧下傳統 Softmax Attention 的計算公式:O=softmax ((QK^T)⊙M) V,其 Q, K, V, M, O 分別為 Query, Key, Value, Mask 和 Output 矩陣,這裡的 M 在單向任務(如 GPT)中是一個下三角的全 1 矩陣,在雙向任務(如 BERT)中則可以忽略,即雙向任務沒有 Mask 矩陣。我們下面將 LASP 拆為四點進行解釋:

Linear Attention 原理

Linear Attention 可以視為 Softmax Attention 一種變體。Linear Attention 去除了計算成本高昂的 Softmax 運算元,Attention 的計算公式可以寫為 O=((QK^T)⊙M) V 的簡潔形式。但由於單向任務中 Mask 矩陣 M 的存在,使得該形式依然只能進行左乘計算(即先計算 QK^T),從而不能獲得 O (N) 的線性複雜度。但對於雙向任務,由於沒有 Mask 矩陣的存在,其計算公式可以進一步簡化為 O=(QK^T) V。Linear Attention 的巧妙之處在於,僅僅利用簡單的矩陣乘法結合律,其計算公式就可以進一步轉化為:O=Q (K^T V),這種計算形式被稱之為右乘,可見 Linear Attention 在這種雙向任務中可以達到誘人的 O (N) 複雜度!

圖片

LASP 資料分發

LASP 首先將長序列資料從序列維度切分為多個等分的子序列,再將子序列分散傳送至序列並行通訊組內的所有 GPU,使得每張 GPU 上各有一段子序列,以供後續序列並行的計算使用。

圖片

LASP 核心機制

隨著 decoder-only 的類 GPT 形式的模型逐漸成為 LLM 的事實標準,LASP 的設計充分考慮了單向 Casual 任務的場景。由切分後子序列 Xi 計算而來的便是按照序列維度切分的 Qi, Ki, Vi,每一個索引 i 對應一個 Chunk 和一個 Device(即一張 GPU)。由於 Mask 矩陣的存在,LASP 作者巧妙地將各個 Chunk 對應的 Qi, Ki, Vi 區分為兩種,即:Intra-Chunk 和 Inter-Chunk。其中 Intra-Chunk 為 Mask 矩陣分塊後對角線上的 Chunk,可以認為仍然有 Mask 矩陣的存在,依然需要使用左乘;Inter-Chunk 則為 Mask 矩陣非對角線上的 Chunk,可以認為沒有 Mask 矩陣的存在,可以使用右乘;顯然,當切分的 Chunk 越多時,對角線上的 Chunk 佔比越少,非對角線上的 Chunk 佔比越多,可以利用右乘實現線性複雜度 Attention 計算的 Chunk 就越多。其中,對於右乘的 Inter-Chunk 的計算,前向計算時每個裝置需要使用點對點通訊 Recive 上一個裝置的 KV,並 Send 自己的更新後的 KV 給下一個裝置。反向計算時則正好相反,只是 Send 和 Recive 的物件變為了 KV 的梯度 dKV。其中前向計算過程如下圖所示:

圖片

LASP 程式碼實現

為了提高 LASP 在 GPU 上的計算效率,作者對 Intra-Chunk 和 Inter-Chunk 的計算分別進行了 Kernel Fusion,並將 KV 和 dKV 的更新計算也融合到了 Intra-Chunk 和 Inter-Chunk 計算中。另外,為了在反向傳播過程中避免重新計算啟用 KV,作者選擇在前向傳播計算後立即將其儲存在 GPU 的 HBM 中。在隨後的反向傳播過程中,LASP 直接訪問 KV 以供使用。需要注意的是,儲存在 HBM 中的 KV 大小為 d x d,完全不受序列長度 N 的影響。當輸入序列長度 N 較大時,KV 的記憶體佔用變得微不足道。在單張 GPU 內部,作者實現了由 Triton 實現的 Lightning Attention 以減少 HBM 和 SRAM 之間的 IO 開銷,從而加速單卡 Linear Attention 計算。

想要了解更多細節的讀者,可以閱讀論文中的 Algorithm 2(LASP 前向過程)和 Algorithm 3(LASP 反向過程),以及文中詳細的推導過程。

通訊量分析

LASP 演算法中需要注意前向傳播需要在每個 Linear Attention 模組層進行 KV 啟用的通訊。通訊量為 Bd^2/h,其中 B 是 batch 大小,h 是頭數。相比之下,Megatron-SP 在每個 Transformer 層中的兩個 Layer Norm 層之後分別使用了一次 All-Gather 操作,並在 Attention 和 FFN 層之後分別使用了一次 Reduce-Scatter 操作,這導致其通訊量為 2BNd + 4BNd/T,其中 T 為序列並行維度。DeepSpeed-Ulysses 使用了 All-to-All 集合通訊操作來處理每個 Attention 模組層的輸入 Q, K, V 和輸出 O,導致通訊量為 4BNd/T。三者的通訊量對比如下表所示。其中 d/h 是頭維度,通常設定為 128。在實際應用中,當 N/T>=32 時,LASP 便能夠實現最低的理論通訊量。此外,LASP 的通訊量不受序列長度 N 或子序列長度 C 的影響,這對於跨大型 GPU 叢集的極長序列平行計算是一個巨大的優勢。

圖片

Data-Sequence 混合並行

資料並行(即 Batch-level 的資料切分)已經是分散式訓練的常規操作,在原始資料並行(PyTorch DDP)的基礎上,已經進化出了更節省視訊記憶體的切片式資料並行,從最初的 DeepSpeed ZeRO 系列到 PyTorch 官方支援的 FSDP,切片式資料並行已經足夠成熟並被越來越多使用者使用。LASP 作為 Sequence-level 的資料切分方法,可以能夠和包括 PyTorch DDP, Zero-1/2/3, FSDP 在內的各種資料並行方法相容使用。這對 LASP 的使用者來說無疑是好訊息。

精度實驗

在 TransNormerLLM (TNL) 和 Linear Transformer 上的實驗結果表明,LASP 作為一種系統最佳化方法能夠和各種 DDP backends 結合,並均能達到與 Baseline 持平的效能。

圖片

可擴充套件性實驗

得益於高效的通訊機制設計,LASP 可以輕鬆擴充套件至上百卡 GPU,並保持很好的可擴充套件性。

圖片

速度對比實驗

與成熟的序列並行方法 Megatron-SP 和 DeepSpeed-Ulysses 對比,LASP 可訓練的最長序列長度是 Megatron-SP 的 8 倍、DeepSpeed-Ulysses 的 4 倍,速度則分別快了 136% 和 38%。

圖片

結語

為了方便大家試用,作者已經提供了一個即裝即用的 LASP 程式碼實現,無需下載資料集和模型,只需 PyTorch 分分鐘體驗 LASP 的極長極快序列並行能力。

程式碼傳送門:https://github.com/OpenNLPLab/LASP

相關文章