近8年後,谷歌Transformer繼任者「Titans」來了,上下文記憶瓶頸被打破

机器之心發表於2025-01-15

正如論文一作所說,「新架構 Titans 既比 Transformer 和現代線性 RNN 更有效,也比 GPT-4 等超大型模型效能更強。」


終於,在 2017 年推出影響 AI 行業長達 8 年的 Transformer 架構之後,谷歌帶來了全新的架構 Titans。這次,谷歌的重點是將推理領域非常重要的測試時(test-time)計算用在了記憶(memory)層面。

在談到推出 Titans 的初衷時,論文一作 Ali Behrouz 表示,「注意力機制一直是大多數 LLM 進展的重要組成部分,不過它無法擴充套件到長上下文。因此,Titans 應運而出,它成為了一種同時具備註意力機制和元上下文記憶的結構,可以在測試時學習記憶。該架構可以將上下文視窗擴充套件到 200 萬 tokens。」
圖片
圖源:https://x.com/behrouz_ali/status/1878859086227255347

這意味著,谷歌 Transformer 迎來了它的「繼任者」。
圖片
圖源:https://x.com/mark_k/status/1878896628654022993

多年來,研究人員一直在廣泛探究如何有效地利用迴圈模型和注意力機制,其中迴圈模型旨在將資料壓縮到固定大小的記憶(稱為隱狀態)中,而注意力機制允許處理整個上下文視窗,捕捉所有 token 的直接依賴。不過,更準確的依賴建模往往伴隨著二次成本,導致模型只能處理固定長度的上下文。

因此,谷歌提出了一種新的長期神經記憶模組(neural memory module),它能夠學習記憶歷史上下文,並幫助注意力機制在利用過去已久資訊的同時處理當前上下文。結果表明,這種神經記憶具有快速並行化訓練的優勢,同時還能保持快速推理。

從記憶的角度來看,谷歌認為注意力機制雖然受限於上下文但可以更準確地建模依賴關係,因此可以起到短期記憶的作用;而神經記憶能夠對資料進行記憶,起到了長期、更持久的記憶作用。基於這兩個模組,谷歌引入了一個全新的系列架構 —— Titans,透過三種變體有效地將記憶融合到該系統架構中,它們分別是記憶作為上下文(Memory as a Context,MAC)、記憶作為門(Memory as a Gate,MAG)和記憶作為層(Memory as a Layer,MAL)

在語言建模、常識推理、基因組學和時序預測任務上的實驗結果表明,Titans 架構比 Transformer 和近年來的現代線性迴圈模型更有效。另外,在大海撈針(needle-in-haystack)中,Titans 架構能夠有效地擴充套件到超過 200 萬 tokens 的上下文視窗,並且比基準模型實現了更高的準確性。
圖片
  • 論文標題:Titans: Learning to Memorize at Test Time

  • 論文地址:https://arxiv.org/pdf/2501.00663v1

另外,論文作者之一 Peilin Zhong 為谷歌 NYC 演算法與最佳化團隊的研究科學家,2021 年加入谷歌。他本科畢業於清華姚班,博士畢業於哥倫比亞大學。

目前,已經有人搞出了有關 Titans 架構的非官方實現,感興趣的讀者可以去看一下。
圖片
GitHub 地址:https://github.com/lucidrains/titans-pytorch

學習測試時記憶

谷歌詳細介紹了長期神經記憶模組,它成為了一種可以在測試時學習記憶的元模型。

長期記憶

為了設計一個長期神經記憶模組,我們需要模型能夠將過去歷史的抽象編碼到其引數中。因此,一個簡單的思路是訓練神經網路並期望它能夠記住自己的訓練資料,然而記憶幾乎一直是神經網路中令人頭疼的現象,它限制了模型的泛化能力,還引發隱私問題,因此導致測試時效能不佳。

基於此,谷歌認為需要一個線上元模型來學習如何在測試時記憶或忘記資料。在這種設定下,模型學習一個能夠記憶的函式,但不會過擬合訓練資料,從而在測試時實現更好的泛化效能。

學習過程和意外指標(Learning Process and Surprise Metric)。訓練長期記憶的關鍵思路是將訓練視為線上學習問題,其中將過去資訊 x_1, …, x_t-1 壓縮到長期神經記憶模組中。人類往往能夠記住背離預期(令人驚訝)的事件,受此啟發,模型意外可以簡單定義為它相對於輸入的梯度。梯度越大,輸入資料與過去資料的偏差就越大。因此,使用這個意外分數,可以將記憶更新如下:
圖片
這一意外指標可以導致在重大意外時刻之後出現重要資訊缺失。從人類記憶的角度來看,即使一個事件令人難忘,但它可能不會在長時間內持續讓我們感到驚訝。為了改進這一現象,谷歌將意外指標分解為了(1)過去意外,它衡量最近過去的意外程度;(2)瞬時意外,它衡量傳入資料的意外。
圖片
這些意外指標基於一個損失函式圖片,它就是我們的記憶在測試時學習充當的目標。也就是說,記憶模組是一個元模型,它基於損失函式圖片來學習一個函式。

在本文中,谷歌則專注於聯想記憶,目的是將過去的資料儲存為鍵(keys)和值(values)對。類似於 Transformer,在給定 x_t 的情況下,谷歌使用兩個線性層將 x_t 投影到鍵和值中:
圖片
接下來,谷歌希望記憶模組可以學習鍵和值之間的關聯,為此將損失定義如下:
圖片
遺忘機制(Forgetting Mechanism)。在處理非常大的序列(比如百萬 tokens)時,管理哪些過去資訊應該被遺忘非常重要,即使使用深度或者非常大的矩陣值記憶時也是如此。因此,谷歌使用了一種自適應遺忘機制,允許記憶忘記不再需要的資訊,從而更好地管理有限的記憶容量。也就是說,給定下一個 token x_t,谷歌將更新規則做如下修改:
圖片
記憶架構(Memory Architecture)。谷歌重點將具有 L_M≥1 層的簡單 MLP 作為長期記憶架構,選擇它們的原因在於希望能夠更好地激勵長期記憶設計以及將其融入架構的方法。谷歌表示,本文的架構開闢了一個新的研究方向,有助於設計更有效且高效記憶資料的神經架構。

檢索記憶(Retrieving a Memory)。在探討如何設計和訓練一個可以在測試時學習記憶的長期記憶模組之後,剩下的關鍵問題便是如何從記憶中檢索資訊?谷歌僅僅使用了沒有更新權重的前向傳遞(即推理)來檢索與查詢相對應的記憶。在形式上,給定一個輸入 x_t,谷歌使用線性層 W_Q 來投影輸入,即 q_t = x_tW_Q,並透過以下公式從記憶 y_t 中檢索相應(或有用)的資訊。
圖片
並行化長期記憶訓練

理論上,長期記憶模組的訓練需要圖片 FLOPS,其中 N 為序列長度。不過在實踐中,我們需要並行化訓練過程並充分利用 TPU、GPU 等硬體加速器,同時需要張量化該過程並使用更多矩陣乘法(matmuls)。

接下來,谷歌表示,使用小批次梯度下降、資料學習率和權重衰減來計算內迴圈權重的方式可以重新來表示,以便它只使用矩陣乘法和求和(sums)。這裡將序列拆分為大小為 b ≥ 1 的塊,並將小批次梯度下降寫做:
圖片
此外,谷歌解釋了 M_t = W_t 為線性時的情況。對於 N_p ≥ 2 的 MLP,過程類似。使用本文的損失函式可以得到如下:
圖片
最後,谷歌擴充套件這一表示,以便可以合併動量項。在具有動量的塊式梯度下降中,如果觀察動量項則可以得到如下:
圖片
作為函式塊的引數(Parameters as the Function of Chunks)。谷歌沒有讓引數 a_t、θ_t 和 η_t 依賴於輸入,而是讓它們成為函式塊。儘管失去了表達能力,但可以幫助更快地訓練。在這種情況下,谷歌在每個塊中對每一個 a、θ 和 η 都使用了相同的值。在實驗中,谷歌將這些引數作為了 token 的函式,並表示,這種簡化(即作為塊函式)可能是未來工作感興趣的地方,以便以更高效的方式訓練更大的模型。

下圖 1 展示瞭如何並行並在使用矩陣乘法時完成神經記憶訓練。
圖片
如何融合記憶?

接下來需要解決的一個重要問題是:如何有效且高效地將神經記憶融合到深度學習架構中

從記憶的角度來看,Transformer 中的 K 和 V 矩陣對可以解釋為聯想記憶塊。由於它們對依賴關係的精確建模以及有限的上下文視窗,它們可以被用作短期記憶模組,以處理當前上下文視窗大小。另一方面,神經記憶能夠不斷從資料中學習並儲存在其權重中,因而可以發揮長期記憶的作用。谷歌透過三個不同的 Titans 變體來回答以上問題。

記憶作為上下文(Memory as a Context,MAC)

Titans 的第一個變體 MAC 的架構設計如下圖 2 所示,將記憶作為當前資訊的上下文。
圖片
也就是說,給定一個長序列圖片,首先將序列分成固定大小的片段 S^(𝑖),其中 𝑖 = 1,...,𝑁/𝐶。給定傳入片段 S^(𝑡),谷歌將它視為當前上下文,將其過去的片段視為歷史資訊。因此,谷歌讓 M_𝑡-1 成為片段 S^(𝑡) 之前的長期記憶狀態,使用輸入上下文作為對記憶 M^𝑡-1 的查詢,以從長期記憶中檢索相應的資訊。谷歌如下所示檢索與 S^(𝑡) 相對應的過去資訊:
圖片
接下來將這些歷史資訊與持久記憶引數一起作為注意力模組的輸入序列:
圖片
該架構具有兩個關鍵優勢:一是注意力模組同時具有歷史和當前上下文,能夠根據當前資料決定是否需要長期記憶資訊,二是注意力模組幫助長期記憶只儲存來自當前上下文的有用資訊。這意味著,並非每個片段中的所有 token 都是有用的,記憶所有 token 可能會導致記憶體溢位。因此,注意力模組幫助記憶瞭解哪些資訊是有用的,從而更好地管理記憶體容量。

另外,在測試時,(i)持久記憶引數是固定的,它們編碼了有關任務的知識,不應改變;(ii)注意力模組權重是上下文學習器;(iii)長期記憶模組在測試時仍然學習(記憶)資訊。也就是說,即使在測試時,神經記憶的權重也會更新,這是因為權重對過去已久的抽象進行了編碼。

記憶作為門(Memory as a Gate,MAG)

Titans 第二個變體 MAG 的架構設計如下圖 4 所示:
圖片
在其中一個分支中,谷歌直接使用輸入資料來更新長期記憶;在第二個分支中,谷歌使用了滑動視窗注意力(SWA):
圖片
該架構的整體注意力掩碼如下圖 3b 所示,其中滑動視窗注意力(SWA)充當精確的短期記憶,而神經記憶模組充當模型的衰減記憶。該設計也可以看作是多頭架構,其中各頭的結構不同。
圖片
記憶作為層(Memory as a Layer,MAL)

Titans 的第三個變體 MAL 使用了深度神經網路,這種架構設計在文獻中更為常見,其中混合模型堆疊具有完整或滑動視窗注意力的迴圈模型。
圖片
給定輸入 x,可以得到以下:
圖片
其中 SW-Attn 是滑動視窗注意力。

無注意力記憶(Memory Without Attention)。從記憶的角度來看,谷歌期望記憶系統的每個元件都能獨立工作,即使其他元件受到了干擾。因此,即使沒有短期記憶(即注意力),長期記憶模組仍然應該是一個強大的模型。谷歌在實驗中將這種變體稱為 Titans (LMM)。

架構細節

在所有塊中,谷歌使用了殘差連線;在實現中,谷歌使用 SiLU (.) 啟用函式作為計算查詢、鍵和值的非線性啟用,並使用圖片對查詢和鍵進行歸一化。

卷積(Convolution)。遵循最近的現代線性迴圈模型,谷歌在每個查詢、鍵和值投影后都融合了一個 1D 深度可分離卷積層。這些 1D 卷積可以提升效能,並且計算高效。

門控(Gating)。谷歌還在最終輸出投影之前利用線性層進行歸一化和門控。

實驗結果

谷歌在實驗部分關注上述三種 Titans 變體,分別是 MAC、MAG 和 MAL,以及單獨的神經記憶模組。對於每個模型,谷歌使用了四種尺寸的模型,引數分別是 (i) 170M、(ii) 340M、(iii) 400M 和 (iv) 760M。

語言建模

谷歌首先關注模型在語言建模和常識推理任務中的困惑度。下表 1 報告了 Titans 變體和三種不同大小(340M、400M 和 760M)基線的結果。在包括 Transformer++ 在內的非混合模型中,神經記憶模組在困惑度和準確度測量方面均取得了最佳效能。

谷歌還發現,Titans 的三種變體(MAC, MAG 和 MAL)都優於 Samba (Mamba + 注意力)和 Gated DeltaNet-H2(Gated DeltaNet + 注意力)。
圖片
大海撈針

下表 2 結果顯示,與基線相比,神經記憶模組均取得了最佳結果。

谷歌將這種卓越的表現歸因於 Titans 與現有序列模型的三個關鍵差異:(1)與 TTT 相比,神經記憶能夠透過使用動量和遺忘機制(即權重衰減)更好地處理記憶容量。因此,隨著序列長度的增加,神經記憶的效能不會下降,呈現出一致的趨勢;(2)與具有門控(遺忘)機制的 Mamba2 相比,Titans 具有深度非線性記憶,從而實現了更好的記憶管理。此外,與神經記憶和 DeltaNet 不同,Mamba2 無法移除記憶,因此在增加序列長度時,其效能會出現顯著下降;(3)與 DeltaNet 相比,儘管它能夠使用增量規則移除記憶,但無法擦除記憶,缺乏遺忘機制。

最終,正如預期的那樣,使用 Titans 變體時能看到相當或更好的結果,其中最佳結果來自 MAC。
圖片
BABILong 基準

在微調設定中,谷歌將小型微調版本的 Titans (MAC) 與其他模型進行了比較。

Titans 和基線的結果如下圖 6b 所示。Titans 的表現優於所有模型,甚至比 GPT4 這樣的超大型模型還要好。此外,與基於 Transformer 的 RMT 等記憶模型相比,Titans 表現出更好的效能,這主要歸功於其強大的記憶。
圖片
深度記憶的影響

接下來的實驗評估了深度記憶對 wall-clock 訓練時間和模型效能的影響。

下圖 7 中報告了 Titans(LMM)和基線的困惑度與序列長度的關係。有趣的是,隨著記憶深度的增加,該模型可以在所有序列長度上實現更好的困惑度。此外,當模型的引數量較少時,更深的記憶模組對序列長度的魯棒性更強。隨著引數量的增加,所有模型在較長的序列上都表現出更好的效能。
圖片
時序預測

為了展示記憶模組在更廣泛任務中的有效性,谷歌評估了 Titans 在時序預測任務中的表現。結果如下表 3 所示,谷歌的神經記憶模組優於所有基線,包括基於 Mamba、線性和 Transformer 的架構。
圖片
DNA 建模

谷歌還進一步評估了神經記憶模組在 DNA 建模任務上的表現,結果如下 4 所示,相較於當前的 SOTA 架構,Titans(LMM)在不同的下游基因組任務中仍具有競爭力。
圖片
效率

谷歌還對 Titans 與當前 SOTA 序列模型的效率進行了比較,下圖 9 顯示了不同序列長度 x 批大小的模型的訓練吞吐量。可以看到,谷歌神經記憶模組比 Mamba2 和 Gated DeltaNet 稍慢,不過 Titans (MAL) 比基線和神經記憶模組都要快。
圖片
更多技術細節和實驗結果請參閱原論文。

相關文章