鄂維南院士領銜新作:大模型不止有RAG、引數儲存,還有第3種記憶

机器之心發表於2024-07-10
2.4B 的 Memory3比更大的 LLM 和 RAG 模型獲得了更好的效能。

近年來,大型語言模型 (LLM) 因其非凡的效能而獲得了前所未有的關注。然而, LLM 的訓練和推理成本高昂,人們一直在嘗試透過各種最佳化方法來降低成本。

本文來自上海演算法創新研究院、北京大學等機構的研究者受人類大腦記憶層次結構的啟發,他們透過為 LLM 配備顯式記憶(一種比模型引數和 RAG 更便宜的記憶格式)來降低這一成本。從概念上講,由於其大部分知識都外化為顯式記憶,因而 LLM 可以享受更少的引數大小、訓練成本和推理成本。
圖片
  • 論文地址:https://arxiv.org/pdf/2407.01178
  • 論文標題:Memory3 : Language Modeling with Explicit Memory

作為初步的概念證明,研究者從零開始訓練了一個 2.4B 的 LLM,它比更大的 LLM 和 RAG 模型獲得了更好的效能,並實現了比 RAG 更高的解碼速度。這個模型被命名為 Memory3,因為在 LLM 中,顯式記憶是繼隱式記憶(模型引數)和工作記憶(上下文鍵值)之後的第三種記憶形式。
圖片
具體而言,本文引入了一種新的記憶格式,即顯式記憶,其特點是寫入成本和讀取成本相對較低。如圖 1 所示,模型首先將知識庫(或任何文字資料集)轉換為顯式記憶,實現為稀疏注意力鍵 - 值,然後在推理過程中呼叫這些記憶體並將其整合到自注意力層中。
圖片
新的記憶格式定義了新的記憶層次結構:
圖片
此外,本文還介紹了一種支援知識外化的記憶電路理論,並提出了可以讓儲存易於處理的記憶稀疏機制和促進記憶形成的兩階段預訓練方案。

總結而言:

  • Memory3 在推理過程中利用顯式記憶,減輕了模型引數記憶特定知識的負擔;
  • 顯式記憶是從構建的知識庫中編碼而來的,其中稀疏記憶格式保持了真實的儲存大小;
  • 研究者從頭開始訓練了一個具有 2.4B 非嵌入引數的 Memory3 模型,其效能超過了更大規模的 SOTA 模型。它還比 RAG 具有更好的效能和更快的推理速度;
  • 此外,Memory3 提高了事實性並減輕了幻覺,並能夠快速適應專業任務。

方法介紹

記憶電路理論有助於確定哪些知識可以儲存為顯式記憶,以及哪種模型架構適合讀取和寫入顯式記憶。
圖片
研究者將輸入輸出關係作為電路的內部機制,並將知識定義為輸入輸出關係及其電路。透過操縱這些電路,人們可以從 LLM 中分離出許多知識,同時保持其功能完好無損。

Memory3:在架構方面,本文的目標是為 Transformer LLM 設計一個顯式的記憶機制,使其寫入成本和讀取成本都比較低。此外,本文希望將對 Transformer 架構的修改限制在儘可能小的範圍內,不新增任何新的可訓練引數,這樣大多數現有的 Transformer LLM 都可以在幾乎不進行微調的情況下轉換為 Memory3 模型。簡單的設計過程如下:

寫入成本:在推理之前,LLM 將每個參考寫入顯式記憶,儲存在驅動器上。記憶是從自注意力層的鍵值向量中選擇的,因此寫入過程不涉及訓練。每個引用都是獨立處理的,避免了長上下文注意力的成本。

讀取成本:在推理過程中,顯式記憶從驅動器中檢索,並與通常的上下文鍵值一起由自注意力讀取。每個記憶由來自少量注意力頭的極少量鍵值組成,從而大大減少了額外的計算、GPU 儲存、驅動器儲存和載入時間。它允許 LLM 頻繁檢索許多參考,而對解碼速度的影響有限。

推理過程如圖 9 所示,每當 LLM 生成 64 個 token 時,它就會丟棄當前記憶,使用這 64 個 token 作為查詢文字來檢索 5 個新記憶,並繼續使用這些記憶進行解碼。同樣,在處理提示時,LLM 會為每 64 個 token 塊檢索 5 個記憶。每個塊都會關注自己的記憶,並且不同塊之間的記憶可能會有所不同。
圖片
寫入與讀取記憶:在推理過程中,LLM 可以透過其自注意力層直接讀取檢索到的顯式記憶,方法是將它們與上下文鍵值連線起來(圖 9)。具體來說,對於第 l 層的每個注意力頭 h,如果它被選為記憶頭,那麼它的輸出 Y^( l,h ) 將會改變:
圖片
此外,該研究對所有顯式記憶採用並行位置編碼,即所有鍵位置都位於長度為 128 的同一區間內,如圖 9 所示。

兩階段預訓練:預訓練由兩個階段組成,warmup 和持續訓練。只有持續訓練階段涉及顯式記憶,而 warmup 階段使用與普通預訓練相同的格式。
圖片
圖 13 繪製了 warmup 階段訓練損失和學習率時間表。
圖片
圖 14 繪製了持續訓練階段訓練損失和學習率時間表。
圖片
實驗結果

研究者評估了 Memory3 模型的一般能力(基準任務)、對話能力、專業能力(法律和醫學)以及幻覺。此外,研究者還測量了 Memory3 的解碼速度,並與類似和更大的 SOTA LLM 以及 RAG 模型進行了比較。

一般能力的評估結果如下所示,結果表明顯式記憶使平均分提高了 2.51%。相比之下,Llama2-7B 與 13B 的得分差距為 4.91%。顯式記憶可以將「有效模型大小」提高 2.51/4.91 ≈ 51.1%。
圖片
接下來作者評估了 Memory3 的對話技巧,結果列於表 18 中,表明模型以更少的引數勝過 Vicuna-7B、Falcon-40B-Instruct 和 ChatGLM2-6B。
圖片
目前,LLM 仍然面臨幻覺問題。從概念上講,Memory3 應該不太容易受到幻覺的影響,因為它的顯式記憶直接對應於參考文字。為了評估幻覺,研究者選擇了兩個英文資料集進行評估。結果如表 19 所示,Memory3 在大多數任務上都取得了最高分。
圖片
使用顯式記憶的一個好處是,LLM 可以透過更新其知識庫輕鬆適應新領域和任務。只需將與任務相關的參考匯入 Memory3 的知識庫,並可選擇在熱啟動的情況下將其轉換為顯式記憶。然後,該模型可以利用這些新知識進行推理,跳過成本更高且可能有損的微調過程,並且執行速度比 RAG 更快。圖 4 已證明這種成本降低,並且可以促進 LLM 在各個行業的快速部署。
圖片
下表表明,Memory3 的表現優於大多數模型。
圖片
最後,研究者透過每秒生成的 token 數來評估 Memory3 的解碼速度或吞吐量。
圖片
瞭解更多內容,請參考原論文。

相關文章