徹底改變語言模型:全新架構TTT超越Transformer,ML模型代替RNN隱藏狀態

机器之心發表於2024-07-10

從 125M 到 1.3B 的大模型,效能都有提升。


難以置信,這件事終於發生了。

一種全新的大語言模型(LLM)架構有望代替至今在 AI 領域如日中天的 Transformer,效能也比 Mamba 更好。本週一,有關 Test-Time Training(TTT)的論文成為了人工智慧社群熱議的話題。

圖片

論文連結:https://arxiv.org/abs/2407.04620

該研究的作者來自史丹佛大學、加州大學伯克利分校、加州大學聖迭戈分校和 Meta。他們設計了一種新架構 TTT,用機器學習模型取代了 RNN 的隱藏狀態。該模型透過輸入 token 的實際梯度下降來壓縮上下文。

該研究作者之一 Karan Dalal 表示,他相信這將根本性的改變語言模型方法。
圖片
在機器學習模型中,TTT 層直接取代 Attention,並透過表達性記憶解鎖線性複雜性架構,使我們能夠在上下文中訓練具有數百萬(有時是數十億)個 token 的 LLM。

作者在 125M 到 1.3B 引數規模的大模型上進行了一系列對比發現,TTT-Linear 和 TTT-MLP 均能匹敵或擊敗最強大的 Transformers 和 Mamba 架構方法。

TTT 層作為一種新的資訊壓縮和模型記憶機制,可以簡單地直接替代 Transformer 中的自注意力層。

圖片

與 Mamba 相比,TTT-Linear 的困惑度更低,FLOP 更少(左),對長上下文的利用更好(右):

圖片

這不僅在理論上是線性的複雜度,而且實際執行時間也更快。

圖片

  • 在論文上線後,作者公開了程式碼與 jax 以供人們訓練和測試:https://github.com/test-time-training/ttt-lm-jax
  • 還有 PyTorch 推理程式碼:https://github.com/test-time-training/ttt-lm-pytorch

方法介紹

長上下文的挑戰是 RNN 層本質上所固有的:與自注意力機制不同,RNN 層必須將上下文壓縮為固定大小的隱藏狀態,更新規則需要發現數千甚至數百萬個 token 之間的底層結構和關係。

研究團隊首先觀察到自監督學習可以將大量訓練集壓縮為 LLM 等模型的權重,而 LLM 模型通常表現出對其訓練資料之間語義聯絡的深刻理解。

受此觀察的啟發,研究團隊設計了一類新的序列建模層,其中隱藏狀態是一個模型,更新規則是自監督學習的一個步驟。由於更新測試序列上的隱藏狀態的過程相當於在測試時訓練模型,因此研究團隊將這種新的層稱為測試時訓練(Test-Time Training,TTT)層。

圖片

研究團隊引入兩個簡單的例項:TTT-Linear 和 TTT-MLP,其中隱藏狀態分別是線性模型和兩層 MLP。TTT 層可以整合到任何網路架構中並進行端到端最佳化,類似於 RNN 層和自注意力。

圖片

為了讓 TTT 層更加高效,該研究採取了一些技巧來改進 TTT 層:

首先,類似於在常規訓練期間對小批次序列採取 gradient step 以獲得更好的並行性,該研究在 TTT 期間使用小批次 token。

圖片

圖片

其次,該研究為每個 TTT 小批次內的操作開發了一種雙重形式,以更好地利用現代 GPU 和 TPU。雙重形式的輸出與簡單實現等效,但訓練速度快了 5 倍以上。如圖 3 所示,TTT-Linear 在 8k 上下文中比 Transformer 更快,與 Mamba 相當。

研究團隊認為:所有序列建模層都可以看作將歷史上下文儲存到隱藏狀態,如圖 4 所示。

圖片

例如,RNN 層(如 LSTM、RWKV 和 Mamba 層)將上下文壓縮為跨時間的固定大小狀態。這種壓縮會產生兩種後果:一方面,將輸入標記 x_t 對映到輸出 token z_t 是高效的,因為每個 token 的更新規則和輸出規則都需要恆定的時間。另一方面,RNN 層在長上下文中的效能受限於其隱藏狀態 s_t 的表現力。

自注意力也可以從上述角度來看待,只不過它的隱藏狀態(通常稱為 Key-Value 快取)是一個隨 t 線性增長的列表。它的更新規則只是將當前的 KV 元組(tuple)追加到該列表中,而輸出規則則掃描 t 前的所有元組,以形成注意力矩陣。隱藏狀態明確儲存了所有歷史上下文,無需壓縮,這使得自注意力在長上下文方面比 RNN 層更具表現力。然而,掃描這個線性增長的隱藏狀態所需的時間也是線性增長的。為了保持長上下文的高效和表現力,研究者需要一種更好的壓縮啟發式。具體來說,需要將成千上萬或可能上百萬的 token 壓縮到一個隱藏狀態中,從而有效捕捉它們的底層結構和關係。這聽起來似乎有些高難度,但實際上很多人都對這種啟發式非常熟悉。

骨幹架構。將任何 RNN 層整合到更大架構中的最簡潔方法是直接替換 Transformer 中的自注意力,在這裡稱為骨幹。然而,現有的 RNN(如 Mamba 和 Griffin 等)都使用了與 Transformer 不同的骨幹層。最值得注意的是,它們的骨幹層在 RNN 層之前包含了時間卷積,這可能有助於收集跨時間的區域性資訊。在對 Mamba 主幹網進行試驗後,研究者發現它也能改善 TTT 層的困惑度,因此將其納入了建議方法中,詳見圖 16。

圖片

實驗結果

在實驗中,研究者將 TTT-Linear 、 TTT-MLP 與 Transformer、Mamba 這兩種基線進行了比較。

短文字

從圖 11 中可以得出以下結論:

  • 2k 上下文,TTT-Linear (M)、Mamba 和 Transformer 的效能相當,因為線條大多重疊。在 FLOP 預算較大的情況下,TTT-MLP (M) 的效能稍差。儘管 TTT-MLP 在各種模型大小下都比 TTT-Linear 有更好的困惑度,但 FLOPs 的額外成本抵消了這一優勢。
  • 8k 上下文,TTT-Linear (M) 和 TTT-MLP (M) 的表現都明顯優於 Mamba,這與 2k 上下文中的觀察結果截然不同。即使是使用 Transformer 主幹網路的 TTT-MLP (T) 在 1.3B 左右也比 Mamba 略勝一籌。一個顯著現象是,隨著上下文長度的增加,TTT 層相對於 Mamba 層的優勢也在擴大。
  • 上下文長度達到 8k,Transformer 在每種模型尺寸下的困惑度依舊錶現不錯,但由於 FLOPs 成本的原因,已不具競爭力。

圖片

上圖結果展示了將 TTT 層從 Mamba 主幹網路切換到 Transformer 主幹網路的影響。研究者假設,當序列建模層的隱藏狀態表現力較低時,Mamba 主幹網路中的時序卷積更有幫助。線性模型的表現力低於 MLP,因此從卷積中獲益更多。

長文字:書籍

為了評估長上下文的能力,研究者使用 Pile 的一個流行子集 Books3,以 2 倍的增量對 1k 到 32k 的上下文長進行實驗。這裡的訓練方法與 Pile 相同,並且 TTT 層的所有實驗都在一次訓練執行中完成。從圖 12 中的結果子集,他們得出了以下觀察結果:

圖片

在 Books 的 2k 上下文中,Pile 2k 的所有觀察結果仍然成立,只是 Mamba 現在的表現略好於 TTT-Linear(而它們的線條在 Pile 2k 中大致重疊)。

在 32k 上下文中,TTT-Linear (M) 和 TTT-MLP (M) 的表現都優於 Mamba,類似於 Pile 8k 的觀察結果。即使是採用 Transformer 主幹的 TTT-MLP (T) 在 32k 上下文中的表現也略好於 Mamba。

TTT-MLP (T) 在 1.3B 規模下僅略差於 TTT-MLP (M)。如上所述,由於缺乏清晰的線性擬合,很難得出經驗縮放定律。然而,TTT-MLP (T) 的強勁趨勢表明,Transformer 主幹可能更適合更大的模型和更長的上下文,超出了我們的評估範圍。

時鐘時間

LLM 的訓練和推理可分解為前向、後向和生成。推理過程中的提示詞處理(也稱為預填充)與訓練過程中的前向運算相同,只是後向操作不需要儲存中間啟用值。

由於前向(訓練和推理過程中)和後向都可以並行處理,因此這裡使用了雙重形式。生成新 token(也稱為解碼)本質上是順序性的,因此這裡使用了原始形式。

研究者提到,由於資源限制,本文實驗使用 JAX 編寫,並在 TPU 上執行。在 v5e-256 TPU pod 上,Transformer 基線在上下文為 2k 的情況下每次迭代訓練需要 0.30 秒,而 TTT-Linear 每次迭代需要 0.27 秒,在沒有任何系統最佳化的情況下快了 10%。鑑於 Mamba(用 PyTorch、Triton 和 CUDA 實現)只能在 GPU 上執行,為了進行公平比較,研究者將本文方法進行初步系統最佳化,使其能在 GPU 上執行。

圖 15 左側顯示了各個模型的前向核心在批大小為 16 時的延遲。所有模型都是 1.3B(Mamba 為 1.4B)。值得注意的是,這裡的 Transformer 基線要比 Mamba 論文中的快得多,因為此處使用了 vLLM ,而不是 HuggingFace Transformer 。

圖片

此外,研究者還編寫了另一個用於生成的 GPU 核心,並在圖 15 右側以批大小 512 為基準測試其速度。另一個常用的掛鐘時間(wall-clock time)指標是吞吐量(throughput),它考慮了使用更大的批大小的潛在好處。對於吞吐量,上述所有觀察結果和方法之間的排序仍然有效。

主要作者

在 TTT 研究提交後,論文作者之一,UCSD 助理教授 Xiaolong Wang 發推表示祝賀。他表示,TTT 的研究持續了一年半,但測試時間訓練(TTT)這個想法從誕生到現在其實已經過去了五年時間。雖然當初的想法和現在的成果完全不同了。

圖片

TTT 論文的三位主要作者分別來自於史丹佛、UC Berkeley 和 UCSD。

其中 Yu Sun 是史丹佛大學的博士後,他博士畢業於 UC Berkeley EECS,長期以來一直的研究方向就是 TTT。

圖片

Xinhao Li 是 UCSD 在讀博士,他本科畢業於電子科技大學。

圖片

Karan Dalal 是 UC Berkeley 在讀博士,他曾在高中時與他人共同創辦了一家名為 Otto 的獸醫遠端醫療初創公司。

圖片

上述三人,都把 test-time training 寫在了個人網站介紹研究方向的第一行。

更多研究細節,可參考原論文。

相關文章