RNN效率媲美Transformer,谷歌新架構兩連發:同等規模強於Mamba

机器之心發表於2024-03-04

去年 12 月,新架構 Mamba 引爆了 AI 圈,向屹立不倒的 Transformer 發起了挑戰。如今,谷歌 DeepMind「Hawk 」和「Griffin 」的推出為 AI 圈提供了新的選擇。


這一次,谷歌 DeepMind 在基礎模型方面又有了新動作。

我們知道,迴圈神經網路(RNN)在深度學習自然語言處理研究的早期發揮了核心作用,並在許多應用中取得了實功,包括谷歌第一個端到端機器翻譯系統。不過近年來,深度學習和 NLP 都以 Transformer 架構為主,該架構融合了多層感知器(MLP)和多頭注意力(MHA)。

Transformer 已經在實踐中實現了比 RNN 更好的效能,並且在利用現代硬體方面也非常高效。基於 Transformer 的大語言模型在從網路收集的海量資料集上進行訓練,取得了顯著的成功。

縱然取得了很大的成功,但 Transformer 架構仍有不足之處,比如由於全域性注意力的二次複雜性,Transformer 很難有效地擴充套件到長序列。此外,鍵值(KV)快取隨序列長度線性增長,導致 Transformer 在推理過程中變慢。這時,迴圈語言模型成為一種替代方案,它們可以將整個序列壓縮為固定大小的隱藏狀態,並迭代更新。但若想取代 Transformer,新的 RNN 模型不僅必須在擴充套件上表現出相當的效能,而且必須實現類似的硬體效率。

在谷歌 DeepMind 近日的一篇論文中,研究者提出了 RG-LRU 層,它是一種新穎的門控線性迴圈層,並圍繞它設計了一個新的迴圈塊來取代多查詢注意力(MQA)。

他們使用該迴圈塊構建了兩個新的模型,一個是混合了 MLP 和迴圈塊的模型 Hawk另一個是混合了 MLP 與迴圈塊、區域性注意力的模型 Griffin

圖片

  • 論文標題:Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
  • 論文連結:https://arxiv.org/pdf/2402.19427.pdf

研究者表示,Hawk 和 Griffin 在 held-out 損失和訓練 FLOPs 之間表現出了冪律縮放,最高可以達到 7B 引數,正如之前在 Transformers 中觀察到的那樣。其中 Griffin 在所有模型規模上實現了比強大 Transformer 基線略低的 held-out 損失。

圖片

研究者針對一系列模型規模、在 300B tokens 上對 Hawk 和 Griffin 進行了過度訓練,結果顯示,Hawk-3B 在下游任務的效能上超越了 Mamba-3B,儘管訓練的 tokens 數量只有後者的一半。Griffin-7B 和 Griffin-14B 的效能與 Llama-2 相當,儘管訓練的 tokens 數量只有後者的 1/7。

此外,Hawk 和 Griffin 在 TPU-v3 上達到了與 Transformers 相當的訓練效率。由於對角 RNN 層受記憶體限制,研究者使用了 RG-LRU 層的核心來實現這一點。

同時在推理過程中,Hawk 和 Griffin 都實現比 MQA Transformer 更高的吞吐量,並在取樣長序列時實現更低的延遲。當評估的序列比訓練中觀察到的更長時,Griffin 的表現比 Transformers 更好,並且可以有效地從訓練資料中學習複製和檢索任務。不過當在未經微調的情況下在複製和精確檢索任務上評估預訓練模型時,Hawk 和 Griffin 的表現不如 Transformers。

共同一作、DeepMind 研究科學家 Aleksandar Botev 表示,混合了門控線性迴圈和區域性注意力的模型 Griffin 保留了 RNN 的所有高效優勢和 Transformer 的表達能力,最高可以擴充套件到 14B 引數規模。

圖片 來源:https://twitter.com/botev_mg/status/1763489634082795780

Griffin 模型架構

Griffin 所有模型都包含以下組成部分:(i) 一個殘差塊,(ii) 一個 MLP 塊,(iii) 一個時間混合塊。所有模型的 (i) 和 (ii) 都是相同的,但時間混合塊有三個:全域性多查詢注意(MQA)、區域性(滑動視窗)MQA 和本文提出的迴圈塊。作為迴圈塊的一部分,研究者使用了真實門控線性迴圈單元(RG-LRU)—— 一種受線性迴圈單元啟發的新型迴圈層。

如圖 2(a)所示,殘差塊定義了 Griffin 模型的全域性結構,其靈感來自 pre-normTransformer。在嵌入輸入序列後,研究者將其透過 𝑁 這樣的塊(𝑁 表示模型深度),然後應用 RMSNorm 生成最終啟用。為了計算 token 機率,應用了最後的線性層,然後是 softmax。該層的權重與輸入嵌入層共享。

圖片

迴圈模型,縮放效率媲美 Transformer

縮放研究為如何調整模型的超引數及其在縮放時的行為提供了重要見解。

研究者定義了本研究中進行評估的模型,並提供了高達和超過 7B 引數的縮放曲線,並評估了模型在下游任務中的效能。

他們考慮了 3 個模型系列:(1)MQA-Transformer 基線;(2)Hawk:純 RNN 模型;(3)Griffin:混合模型,它將迴圈塊與區域性注意力混合在一起。附錄 C 中定義了各種規模模型的關鍵模型超引數

Hawk 架構使用了與 Transformer 基線相同的殘差模式和 MLP 塊,但研究者使用了帶有 RG-LRU 層的迴圈塊作為時序混合塊,而不是 MQA。他們將迴圈塊的寬度擴大了約 4/3 倍(即𝐷_𝑅𝑁𝑁 ≈4𝐷/3),以便在兩者使用相同的模型維度 𝐷 時,與 MHA 塊的引數數量大致匹配。

Griffin。與全域性注意力相比,迴圈塊的主要優勢在於它們使用固定的狀態大小來總結序列,而 MQA 的 KV 快取大小則與序列長度成正比增長。區域性注意力具有相同的特性,而將迴圈塊與區域性注意力混合則可以保留這一優勢。研究者發現這種組合極為高效,因為區域性注意力能準確模擬最近的過去,而迴圈層則能在長序列中傳遞資訊。

Griffin 使用了與 Transformer 基線相同的殘差模式和 MLP 塊。但與 MQA Transformer 基線和 Hawk 模型不同的是,Griffin 混合使用了迴圈塊和 MQA 塊。具體來說,研究者採用了一種分層結構,將兩個殘差塊與一個迴圈塊交替使用,然後再使用一個區域性(MQA)注意力塊。除非另有說明,區域性注意力視窗大小固定為 1024 個 token。

主要縮放結果如圖 1(a)所示。三個模型系列都是在從 1 億到 70 億個引數的模型規模範圍內進行訓練的,不過 Griffin 擁有 140 億引數的版本。

在下游任務上的評估結果如表 1 所示:

圖片

Hawk 和 Griffin 的表現都非常出色。上表報告了 MMLU、HellaSwag、PIQA、ARC-E 和 ARC-C 的特徵歸一化準確率,同時報告了 WinoGrande 的絕對準確率和部分評分。隨著模型規模的增大,Hawk 的效能也得到了顯著提高,Hawk-3B 在下游任務中的表現要強於 Mamba-3B,儘管其訓練的 token 數量只有 Mamba-3B 的一半。Griffin-3B 的效能明顯優於 Mamba-3B,Griffin-7B 和 Griffin-14B 的效能可與 Llama-2 相媲美,儘管它們是在少了近 7 倍的 token 上訓練出來的。Hawk 能與 MQA Transformer 基線相媲美,而 Griffin 的表現則超過了這一基線。

在端側高效訓練迴圈模型

在開發和擴充套件模型時,研究者遇到了兩大工程挑戰。首先,如何在多臺裝置上高效地分片處理模型。第二,如何有效地實現線性迴圈,以最大限度地提高 TPU 的訓練效率。本文討論了這兩個難題,然後對 Griffin 和 MQA 基線的訓練速度進行實證比較。

研究者比較了不同模型大小和序列長度的訓練速度,以研究本文模型在訓練過程中的計算優勢。對於每種模型大小,都保持每批 token 的總數固定不變,這意味著隨著序列長度的增加,序列數量也會按比例減少。

圖 3 繪製了 Griffin 模型與 MQA 基線模型在 2048 個序列長度下的相對執行時間。

圖片

推理速度

LLM 的推理由兩個階段組成。「預填充 」階段是接收並處理 prompt。這一步實際上是對模型進行前向傳遞。由於 prompt 可以在整個序列中並行處理,因此在這一階段,大多數模型操作都是計算受限的因此,研究者預計 Transformers 模型和迴圈模型在預填充階段的相對速度與前文討論的那些模型在訓練期間的相對速度相似。

預填充之後是解碼階段,在這一階段,研究者從模型中自迴歸地採 token。如下所示,尤其是對於序列長度較長時,注意力中使用的鍵值(KV)快取變得很大,迴圈模型在解碼階段具有更低的延遲和更高的吞吐量。

評估推斷速度時有兩個主要指標需要考慮。第一個是延遲,它衡量在特定批次大小下生成指定數量 token 所需的時間。第二個是吞吐量,它衡量在單個裝置上取樣指定數量 token 時每秒可以生成的最大 token 數。因為吞吐量由取樣的 token 數乘以批次大小除以延遲得出,所以可以透過減少延遲或減少記憶體使用以在裝置上使用更大的批次大小來提高吞吐量。對於需要快速響應時間的實時應用來說,考慮延遲是有用的。吞吐量也值得考慮,因為它可以告訴我們在給定時間內可以從特定模型中取樣的最大 token 數量。當考慮其他語言應用,如基於人類反饋的強化學習(RLHF)或評分語言模型輸出(如 AlphaCode 中所做的)時,這個屬性是有吸引力的,因為能夠在給定時間內輸出大量 token 是一個吸引人的特性。

在此,研究者研究了引數為 1B 的模型推理結果。在基線方面,它們與 MQA Transformer 進行了比較,後者在推理過程中的速度明顯快於文獻中常用的標準 MHA 變換器。研究者比較的模型有:i) MQA 變換器,ii) Hawk 和 iii) Griffin。為了比較不同的模型,我們報告了延遲和吞吐量。

如圖 4 所示,研究者比較了批次大小為 16、空預填充和預填充 4096 個 token 的模型的延遲。

圖片

圖 1(b)中比較了相同模型在空提示後分別取樣 512、1024、2048 和 4196 個 token 時的最大吞吐量(token / 秒)。

長上下文建模

本文還探討了 Hawk 和 Griffin 使用較長上下文來改進下一個 token 預測的有效性,並研究它們在推理過程中的外推能力。此外還探討了 Griffin 在需要複製和檢索能力的任務中的表現,既包括在此類任務中訓練的模型,也包括在使用預訓練的語言模型測試這些能力時的表現。

從圖 5 左側的曲線圖中,可以觀察到,在一定的最大長度範圍內,Hawk 和 Griffin 都能在更長的上下文中提高下一個 token 的預測能力,而且它們總體上能夠推斷出比訓練時更長的序列(至少 4 倍)。尤其是 Griffin,即使在區域性注意力層使用 RoPE 時,它的推理能力也非常出色。

圖片

如圖 6 所示,在選擇性複製任務中,所有 3 個模型都能完美地完成任務。在比較該任務的學習速度時, Hawk 明顯慢於 Transformer,這與 Jelassi et al. (2024) 的觀察結果類似,他們發現 Mamba 在類似任務上的學習速度明顯較慢。有趣的是,儘管 Griffin 只使用了一個區域性注意力層,但它的學習速度幾乎沒有減慢,與 Transformer 的學習速度不相上下。

圖片

更多細節,請閱讀原論文。

相關文章