在當前 AI 領域,大語言模型採用的主流架構是 Transformer。不過,隨著 RWKV、Mamba 等架構的陸續問世,出現了一個很明顯的趨勢:在語言建模困惑度方面與 Transformer 較量的迴圈大語言模型正在快速進入人們的視線。
令人興奮的是,這些架構在推理期間使用了恆定量的記憶體。不過,受制於有限的記憶體,迴圈語言模型(LM)無法記憶並使用長上下文中的所有資訊,這導致了上下文學習(in-context learning,ICL)質量的不佳。因此,獲得高效大語言模型的關鍵挑戰在於選擇儲存或者丟棄哪些資訊。
在最近的論文《Just read twice: closing the recall gap for recurrent language models》中,來自史丹佛大學、布法羅大學的研究者透過簡單觀察發現,資料在推理期間湧入迴圈語言模型的排序極大地影響了在有限記憶體中預測儲存哪些資訊的難度。
我們假設根據文件 D(比如伽利略・伽利萊的詳細維基百科)來提問:伽利略是什麼時候搬到的佛羅倫薩?這時,如果提示遵循了 [Q, D] 的排序,則模型只需要記住文件 D 中的一個事實即可。相反,如果提示遵循了 [D, Q] 的排序,則模型需要記住所有事實。如下圖 1(左)所示。
因此,本文首先從理論上形式化了資料排序如何影響記憶體需求,然後提出兩種方法來減輕對資料排序的依賴,分別是 Just-read-twice(JRT)提示策略和 JRT 迴圈架構。本文主要分為以下幾個部分展開:
理解資料排序的作用。研究者得出的第一個洞見是:記憶問題的 hardness 要降低到與設定剝離(set disjointness,SD)相同,這是通訊複雜度理論中持續數十年的最典型問題。SD 要求一種流演算法(比如迴圈模型)來決定上下文中提供的輸入集是否剝離:
理論分析和實驗結果表明,第一個集 | A | 掌控了求解 SD 所需的記憶體。因果模型需要儲存 A 中的所有元素以與 B 中的元素進行比較。這表明了,使用上下文中的「正確資料排序」(如將最小 min (|A|, |B|) 的集放在首位)將有助於記憶體受限的模型。更進一步,觀察到上下文非因果邏輯的模型可在空間最小的 (|A|, |B|) 中求解 SD,而無需考慮資料排序。
其次是利用「正確的」排序。本文提出了一種非常簡單的 JRT-Prompt 策略,在模型生成答案之前在上下文中將資訊重複多次(如上圖 1 右所示)。在第二以及更多輪次中,語言模型在決定儲存哪些資訊時要以完整的上下文為條件,從而有效避免了將資料排序「歸正」的問題。
結果表明,JRT-Prompt 在 16 個已有迴圈語言模型和 6 項 ICL 任務上,實現了平均 11.0 ± 1.3 百分點的提升,而吞吐量是 FlashAttention-2(長度 32k、批大小 16)的 11.9 倍。JRT-Prompt 雖然增加了上下文長度,但漸進來看仍然比注意力更加地計算和記憶體高效。
超越因果模型。本文提出了 JRT-RNN,它的靈感來源於簡單的 Prefix-LM 編碼器解碼器架構設計。大多數的上下文學習輸入包含兩部分內容,分別是輸入的提示(上下文、指令)和作為輸出的模型生成文字。在 Prefix-LM 架構中,LM 並沒有遵循因果邏輯地處理提示區域,而對輸出進行了因果解碼,其中在因果區域僅使用了標準的下一個 token 預測損失,以及非因果區域上的損失。
不過遺憾的是,此前 Prefix-LM 模型的訓練方法取得的成功有限,並使用了低效的 Transformer 主幹。因此本文透過一些簡單的改變來提高質量和效率,包括改進訓練損失並使用稱之為「Prefix Linear Attention,PLA」 的線性注意力公式。研究者發現,使用他們的 IO 感知實現,JRT-RNN 在 360m 和 1.3b 引數設定下,分別可以提供 13.7 和 6.9 百分點的平均質量改進,吞吐量是 FA2 的 19.2 倍。
論文地址:https://arxiv.org/pdf/2407.05483
專案主頁:https://github.com/HazyResearch/prefix-linear-attention
JRT-Prompt 方法概覽
上下文學習任務以 (C, Q, Y) 作為輸入,其中 C 為一些上下文來源(如文件或程式碼儲存庫),Q 為給定上下文時對模型的一些問題或請求,Y 為答案。對於使用自迴歸 LM A 的標準上下文學習,研究者輸入 C 和 Q,並根據正確的完成情況 Y 來評估生成的輸出 Yˆ = A (C, Q)。
JRT-Prompt 是一種極其簡單的方法,在提示模型輸出答案之前會在上下文中重複提示中的資訊(如問題和文件),例如下圖 1 右的 Yˆ = A (C, Q, C, Q)。因此,在上下文第二次出現時,模型根據完整的上下文來決定儲存哪些資訊。
此外,JRT-Prompt 可以與現成的 LLM 一起使用。研究者在零樣本提示下,在一系列記憶密集型上下文任務上評估了以下 LM:
Based 預訓練 LM,引數規模為 1.3B,在 Pile 的 10 − 50B 個 token 上進行訓練;
Mamba 預訓練的 LM,引數規模為 130M、370M、1.4B 和 2.8B,在 Pile 的 300B 個 token 上進行訓練;
Gated Linear Attention 預訓練的 LM,引數規模為 1.3B 和 2.7B,在 SlimPajama 資料集的 100B 個 token 上進行訓練;
Mamba-2 預訓練的 LM,引數規模為 130M、370M、1.3B 和 2.7B,在 Pile 的 300B 個 token 上進行訓練。
結果如下表 1 所示,透過增加狀態(state)大小,研究者發現 JRT-Prompt 方法在各個模型和任務上平均帶來了 11.0 ± 1.3 百分點的效能提升,利用該方法的 Based 模型平均優於利用標準提示的 Transformer 模型。
他們還發現,JRT-Prompt 可以使 Transformer 模型受益,並且該方法在一些任務上(附錄 2)比少樣本學習更加有效。值得注意的是,Springer 等人在論文《Repetition improves language model embeddings》中提出使用自迴歸 Transformer 模型來重複上下文以實現生成嵌入的目的,本文的研究結果也類似。研究者專注於亞二次架構和上下文學習任務。
JRT-Prompt 雖然由於重複而增加了上下文長度,但是其使用的亞二次迴圈架構仍比使用二次 Transformer 模型更高效。研究者發現,在序列長度 N = 32768、批大小為 16 時,使用 JRT-Prompt(序列長度 2N)在英偉達 H100 上提供的吞吐量是 FlashAttention-2(序列長度 N)的 11.9 倍。
JRT-RNN:編碼器 - 解碼器迴圈架構
JRT-RNN 的靈感來自於 Prefix-LMs,但側重於擴充套件質量 - 效率權衡空間的帕累託邊界(Pareto frontier)。為了提高質量,JRT-RNN 在編碼器端使用了單獨的 k_e 和 v_e 對映,在解碼器端使用了 k_d 和 v_d 對映。雖然 Prefix LM 模型對編碼器和解碼器區域使用了共享對映權重,但研究者發現使用兩組對映可以提高質量。
為了提高效率,JRT-RNN 為編碼器使用了非因果線性注意力,而為解碼器使用標準因果線性注意力。研究者稱為 Prefix Linear Attention(PLA)(圖 1 右),公式如下:
JRT-RNN 訓練目標。Prefix LMs 通常不計算非因果區域的損失,而 JRT-RNN 將下一個 token 預測與掩碼語言建模(MLM)目標進行了結合。並且對於新增的 MLM 目標,研究者用一個 [MASK] token 替換了來自編碼器區域 {u_1, ..., u_M} 的比例為 P 的 tokens,並在預測原始 token 時測量了交叉熵損失。
損失如下:
實驗結果
在實驗中,研究者評估了 JRT-RNN 在以下三個指標上的質量和效率:
上下文學習質量
整體語言建模
生成
上下文學習質量
如下表 2 所示,研究者發現,JRT-RNN 在引數為 360M(30B tokens)時比僅解碼器的基線(Based)平均高出 13.7 個百分點,在引數為 1.3B(50B tokens)時平均高出 6.9 個百分點。
同時,JRT-RNN 在引數為 360M 和 1.3B 時與 Transformer++ 的差距分別縮小到了 0.5 個百分點和 1.9 個百分點之內。
在下表 3 中,研究者比較了當 prefill 長度 l 小於編碼器長度 M 時,JRT-RNN 與同類推理策略的表現。
整體自然語言理解
根據以往研究,研究者進一步將困惑度分為了兩組:聯想記憶「AR slice」包括了被稱為「AR hits」的 tokens,它們需要模型按照順序執行記憶以正確地預測下一個 token;而「Other slice」包含剩餘的 tokens(如記憶的知識)。
對於記憶頻率,JRT-RNN 在「AR slice」表現出色。對於訓練期間不常見的二元組(即不太可能在模型引數中被記住的),JRT-RNN 的困惑度相對於 Based 和 Mamba 這兩個強大的因果迴圈基線有所改善。
對於記憶距離,在「AR slice」中,JRT-RNN 與僅解碼器基線之間的差距隨著上下文中重複二元組的增加而擴大。這也進一步證明了 JRT-RNN 可以幫助完成更長的上下文記憶任務。
非記憶頻率。對於訓練期間很少見到的二元組的非記憶「Other slice」,JRT-RNN 的困惑度比僅解碼器的 LM 更差。這是意料之中的結果,因為 JRT-RNN 計算了僅解碼器 LM 的 65% tokens 的損失。
我們預計這一差距會隨著規模和訓練時間的延長而縮小(隨著二元語法頻率的增加而增加)(圖 3,左上角)。
生成吞吐量
生成可以分解為提示「prefill 處理」和解碼「下一個 token 預測」兩步。相較於標準的僅解碼器迴圈模型,JRT-RNN 不會修改解碼步驟,因此討論重點在 prefill 階段。
使用 Simran Arora 等人論文《Simple linear attention language models balance the recall-throughput tradeof》中提出的 Based CUDAn 核心,JRT-Prompt 在處理 prefill 時吞吐量分別是 FlashAttention-2 和 FLA Triton 核心的 11.9 和 13.7 倍,如下表 5 所示。
當研究者將批大小增加到 64 時,JRT-Prompt 吞吐量分別是 FlashAttention-2 和 FLA Triton 核心的 6.1 倍和 7.2 倍。
接下來他們擴充套件了 Based 核心以支援 JRT-RNN,並且證明了當將序列長度增加到 32768 時,吞吐量分別是 FlashAttention-2 和 FLA 的 19.2 倍和 22.0 倍。當將批大小增加到 64 時,JRT-RNN 分別又提供了 9.7 倍和 11.5 倍的吞吐量提升。JRT-RNN 所需的時間是 Based prefill 的 1.24 倍,比 JRT-Prompt 更加高效。
更多技術細節和實驗結果請參閱原論文。