再戰Transformer!原作者帶隊的Mamba 2來了,新架構訓練效率大幅提升

机器之心發表於2024-06-04

自 2017 年被提出以來,Transformer 已經成為 AI 大模型的主流架構,一直穩居語言建模方面 C 位。

但隨著模型規模的擴充套件和需要處理的序列不斷變長,Transformer 的侷限性也逐漸凸顯。一個很明顯的缺陷是:Transformer 模型中自注意力機制的計算量會隨著上下文長度的增加呈平方級增長。

幾個月前,Mamba 的出現打破了這一局面,它可以隨上下文長度的增加實現線性擴充套件。隨著 Mamba 的釋出,這些狀態空間模型 (SSM) 在中小型規模上已經實現了與 Transformers 匹敵,甚至超越 Transformers。

Mamba 的作者只有兩位,一位是卡內基梅隆大學機器學習系助理教授 Albert Gu,另一位是 Together.AI 首席科學家、普林斯頓大學電腦科學助理教授 Tri Dao。

Mamba 面世之後的這段時間裡,社群反應熱烈。可惜的是,Mamba 的論文卻慘遭 ICLR 拒稿,讓一眾研究者頗感意外。

僅僅六個月後,原作者帶隊,更強大的 Mamba 2 正式釋出了。

圖片

  • 論文地址:https://arxiv.org/pdf/2405.21060

  • GitHub 地址:https://github.com/state-spaces/mamba

  • 論文標題:Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality

總體而言,本文提出了 SSD(state space duality)框架,基於此,研究者設計了一個新的體系架構 Mamba-2,其核心層是對 Mamba 的選擇性 SSM 的改進,速度提高了 2-8 倍,同時在語言建模方面繼續與 Transformers 競爭。

Tri Dao 表示,他們構建了一個豐富的 SSD 理論框架,許多線性注意力變體和 SSM 是等效的,由此產生的模型 Mamba-2 比 Mamba-1 更好、更快。

圖片

Mamba-2 的新演算法使其能夠利用更大的狀態維度 (16 → 256),同時訓練速度更快。在需要更大狀態容量的任務上,例如 MQAR 任務,它比 Mamba-1 有了顯著的改進。

圖片

此外研究者還發現,最近新出的混合模型(Jamba、Zamba)增加了一些注意力層來提高模型質量。基於這些發現,研究者將 4-6 個注意力層與 Mamba-2 層混合,其表現優於 Transformer++ 和純 Mamba-2,因而得出注意力和 SSM 是互補的。

圖片

這項研究的貢獻概括為:

本文展示了狀態空間模型與一類稱為半可分矩陣的結構化矩陣族之間的等價性。這一聯絡是 Mamba-2 框架的核心,揭示了狀態空間模型的新屬性和演算法。

本文顯著改進了線性注意力理論,首先透過張量收縮的語言對其迴圈形式提供了一個明確的證明,然後將其推廣到一種新的結構化掩碼注意力(SMA)家族。

本文將 SSM(狀態空間模型)和 SMA(結構化掩碼注意力)聯絡起來,顯示它們有一個很大的交集,彼此是對偶的,同時具有 SSM 式的線性形式和類似注意力的二次方形式。本文還證明了任何具有快速迴圈形式的核注意方法都是 SSM。

除了內在的理論價值外,研究者所提出的框架為理解和改進序列模型開闢了廣闊的方向。

在演算法層面。所提框架為計算 SSM 提供了新的高效且易於實現的演算法。本文提出了一種基於半可分離矩陣塊分解的 SSD 演算法,該演算法利用了 SSM 線性遞推和二次對偶形式,在所有主要效率軸上獲得了最優的權衡。基於 SSD 的實現比 Mamba 的最佳化選擇性掃描實現快 2 到 8 倍,同時允許使用更大的迴圈狀態大小(是 Mamba 的 8 倍甚至更高,且幾乎不影響速度)。SSD 與最佳化過的 softmax 注意力實現(FlashAttention-2)具有高度競爭力,在序列長度 2k 時效能相當,在序列長度 16K 時速度快 6 倍。

架構設計。採用 SSM 等新架構的一個主要障礙是針對 Transformers 量身定製的生態系統,例如用於大規模訓練的硬體高效最佳化和並行技術。本文框架允許使用已建立的慣例和技術來構建 SSM 的架構設計選擇詞彙表,並進一步改進它們。

本文還對 Mamba 塊做了一些修改,這些修改允許實現張量並行,其主要思想包括引入分組值注意力 (GVA,grouped-value attention) 頭結構。

將修改後的並行 Mamba 塊與作為內部 SSM 層的 SSD 結合使用,形成了 Mamba-2 架構。研究者在與 Mamba 相同的設定中研究了 Mamba-2 的 Chinchilla 擴充套件法則,發現它在困惑度和實際執行時間方面均優於 Mamba 和 Transformer++。研究者還在 Pile 資料集上訓練了一系列 Mamba-2 模型,結果顯示 Mamba-2 在標準下游評估中匹配或超過 Mamba 和開源的 Transformers。例如,在 Pile 上訓練了 3000 億 token 的 2.7B 引數的 Mamba-2 在效能上超過了在同一資料集上訓練的 2.8B 引數的 Mamba 和 Pythia 以及 6.9B 引數的 Pythia。

系統最佳化:SSD 框架連線 SSM 和 transformer,允許利用為 transformer 開發的豐富的系統最佳化工作。

圖片

SSD

Mamba-2 的核心貢獻是新的 SSD(state space dual)層。SSD 層可以被定義為選擇性 SSM 的特例。與 Mamba 相比,Mamba-2 的改動會略微降低表達能力,但卻顯著提高了訓練效率,特別是允許在現代加速器上使用矩陣乘法單元。

圖片

圖片

SSD 層的對偶注意力:

圖片

除了最新的 SSD 層,研究者也對 Mamba 的神經網路架構做了一些小的改變,Mamba-2 架構如下所示。

圖片

Mamba-2 在網路架構上的主要變化是從順序生成變為並行生成 SSM 引數,並且 Mamba-2 更適合張量並行等擴充套件方法。

透過提供狀態空間模型的顯式矩陣變換形式,研究團隊揭示了理解和使用它們的新方法。從計算的角度來看,任何計算狀態空間模型前向傳播的方法都可以看作是半可分離矩陣上的矩陣乘法演算法。半可分離矩陣視角為 SSD 提供了一個視角,其中雙重模式分別指的是線性時間半可分離矩陣乘法演算法和二次時間樸素矩陣乘法。

圖片

研究團隊定義了結構化狀態空間模型和結構化注意力,討論了它們的屬性,並表明它們都有二次演算法和線性演算法。

圖片

圖片

自最初的 Mamba 論文研究了合成任務 —— 如:合成複製和歸納 Head 以來,許多後續工作開始研究更難的關聯回憶任務。由 Zoology 和 Based 系列工作引入的 MQAR(multi-query associative recall)任務已成為事實上的標準。

圖片

透過執行一個比文獻中通常報告的版本要難得多的任務,該團隊發現 Mamba-2 明顯優於 Mamba-1,而改善效能的一個原因是狀態大小(比 Mamba-1 大約 16 倍)。

在這篇文章中,作者深入探討了模型背後的理論。

從兩個完全不同的角度推匯出 SSD 的「對偶性」:

  • 一個從 SSM 的角度出發;

  • 另一個從注意力機制的角度出發。

SSD 框架提供了狀態空間模型、注意力機制和結構化矩陣之間豐富的聯絡。

雖然 SSD 模型可以被視為框架內每個分支的具體例項,但 SSD 框架本身更加通用,為未來的工作開闢了許多方向。

圖片

SSD 框架(紅色,藍色):狀態空間模型(即半可分矩陣)和結構化掩碼注意力機制包含了大量高效的序列模型。它們的交集是 SSD 模型(紫色)。

SSD 演算法

通常,矩陣乘法(matmul)的 FLOPs 速度要比非矩陣乘法 FLOPs 快得多(高達 16 倍):A100 GPU 具有 312 TFLOPS 的 BF16 矩陣乘法效能,但只有 19 TFLOPS 的 FP32 算術效能,而 H100 具有 989 TFLOPS 的 BF16 矩陣乘法效能,但只有 67 TFLOPS 的 FP32 算術效能。

Mamba-2 的主要目標之一是「利用張量核心加速 SSM」。

在繫結引數並引入 Head 結構後,Mamba-1 中的 SSM 變成了 SSD,這是一種更具限制性的形式,具有類似注意力的公式。並且由於 SSD 連線 SSM 和結構化矩陣,計算 SSM 的高效演算法直接對應於「token-mixing」或「sequence-mixing」矩陣 M 的不同分解。

圖片

因此,可以透過尋找替代的矩陣乘法方式,例如透過各種方式對其進行分解,從而建立計算 SSM 的新演算法。

透過精心選擇塊大小,對這個矩陣進行簡單塊分解,就可以集 SSD 線性遞迴和二次注意力對偶形式的兩種優勢於一身。

而這也就是 SSD 演算法的起源,它有 4 個步驟,並且對於這個演算法有兩種完全不同的詮釋。

SSD 演算法:分塊矩陣分解

首先將半可分 SSM 矩陣劃分為大小為 Q×Q 的塊,然後,利用半分矩陣的性質來分解每個低秩的非對角塊:

  1. (橙色)每個對角塊是一個更小的半可分矩陣,可以以喜歡的方式計算這個乘法,特別是使用 SSD 的二次(類似注意力機制)形式。

  2. (綠色)總共有 T/Q 個不同的綠色塊,透過批處理矩陣乘法來計算。

  3. (黃色)注意,黃色項本身是一個 1 - 半可分矩陣,這一步等價於對某些修改後的 A 因子的 SSM 掃描。

  4. (藍色)與綠色類似,透過批處理矩陣乘法來計算。

SSD 演算法:分塊和狀態傳遞

該演算法的另一種詮釋涉及「推理 SSM 如何在實際序列上進行操作」。

首先將輸入序列分割成大小為 Q 的塊,步驟可以分為:

  1. 分塊內部輸出:計算每個塊的區域性輸出(假設初始狀態(對於塊)為 0,則每個塊的輸出是多少?)

  2. 塊狀態:計算每個塊的最終狀態(假設初始狀態(對於塊)為 0,則每個塊的最終狀態是多少?)

  3. 傳遞狀態:計算所有塊的最終狀態的遞迴 - 使用任何所需的演算法,例如並行或順序掃描(考慮到所有先前輸入,每個塊的實際最終狀態是多少?)

  4. 輸出狀態:對於每個塊,根據其真實的初始狀態(在步驟 3 中計算),僅從初始狀態得出的輸出計算貢獻

可以看到,大部分演算法(步驟 1、2 和 4)利用了矩陣乘法(因此利用了張量核心),而且可以平行計算。

只有步驟 3 需要掃描,但它只操作一個非常短的序列,通常只需要很少時間。

系統及擴充套件最佳化

張量並行

圖片

使用張量並行對 Mamba-1 進行大規模訓練的一項困難是,每層都需要 2 次 all-reduce,而在 Transformer 中,每個注意力或 MLP 層只需 1 次 all-reduce。這是因為 SSM 的一些引數是內部啟用的函式,而不是層的輸入函式。在 Mamba-2 中,由於採用了「並行投影」結構,所有 SSM 引數都是層輸入的函式,因此可以輕鬆地將張量並行應用於輸入投影:將輸入投影和輸出投影矩陣分割成 2、4、8 個碎片,具體取決於張量並行度。研究者使用 grouped norm,分組數除以張量並行度,這樣每個 GPU 都能單獨完成歸一化。這些變化導致每層只需 1 次 all-reduce,而不是 2 次。

序列並行

圖片

在對超長序列進行訓練時,可能需要沿著序列長度進行分割,並將不同部分分配給不同的裝置。序列並行主要有兩種形式:對於殘差和歸一化操作,用 reduce-scatter、殘差 + 歸一化、然後 all-gather,取代張量並行中的 all-reduce。由於 Mamba-2 使用與 Transformer 相同的殘差和歸一化結構,因此這種形式的序列並行無需修改即可直接應用。對於注意力或 SSM 操作,又稱上下文並行(CP)。對於注意力,可以使用環形注意力沿序列維度進行分割。對於 Mamba-2,SSD 框架再次提供了幫助:使用相同的蒯分解,可以讓每個 GPU 計算其本地輸出和最終狀態,然後在更新每個 GPU 的最終輸出之前,在 GPU 之間傳遞狀態(使用傳送 / 接收通訊原語)。

實驗結果

該研究在 MQAR 的一種具有挑戰性的版本上,使用更難的任務、更長的序列和更小的模型進行了對比實驗。基線包括標準的多頭 softmax 注意力以及 Based 架構,實驗結果如圖 8 所示。

圖片

下表顯示了 Mamba-2 在一系列下游零樣本評估任務上的效能:

圖片

感興趣的讀者可以閱讀論文原文,瞭解更多研究內容。

相關文章