MIT韓松團隊長上下文LLM推理高效框架DuoAttention:單GPU實現330萬Token上下文推理

机器之心發表於2024-10-24
圖片
AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

本文第一作者肖光烜是麻省理工學院電子工程與電腦科學系(MIT EECS)的三年級博士生,師從韓松教授,研究方向為深度學習加速,尤其是大型語言模型(LLM)的加速演算法設計。他在清華大學電腦科學與技術係獲得本科學位。他的研究工作廣受關注,GitHub上的專案累計獲得超過9000顆星,並對業界產生了重要影響。他的主要貢獻包括SmoothQuant和StreamingLLM,這些技術和理念已被廣泛應用,整合到NVIDIA TensorRT-LLM、HuggingFace及Intel Neural Compressor等平臺中。本文的指導老師為韓松教授(https://songhan.mit.edu/)

TL;DR:DuoAttention 透過將大語言模型的注意力頭分為檢索頭(Retrieval Heads,需要完整 KV 快取)和流式頭(Streaming Heads,只需固定量 KV 快取),大幅提升了長上下文推理的效率,顯著減少記憶體消耗、同時提高解碼(Decoding)和預填充(Pre-filling)速度,同時在長短上下文任務中保持了準確率。

圖片

  • 論文連結:https://arxiv.org/abs/2410.10819
  • 專案主頁及程式碼:https://github.com/mit-han-lab/duo-attention

單 GPU 實現 330 萬 Token 上下文推理演示影片:MIT韓松團隊長上下文LLM推理高效框架DuoAttention:單GPU實現330萬Token上下文推理
隨著大語言模型(Large Language Models,LLMs)在各類任務中的廣泛應用,尤其是在長上下文(Long-Context)場景中處理海量文字資訊,如何在保證模型效能的同時減少記憶體和計算成本,成為了一個亟待解決的難題。為此,來自 MIT、清華大學、上海交通大學、愛丁堡大學和 NVIDIA 的研究團隊聯合提出了 DuoAttention 框架。這項創新技術透過對大語言模型的注意力機制(Attention Mechanism)進行精細化設計,極大提高了長上下文推理的效率,並大幅降低了記憶體需求,在不犧牲模型準確性的前提下,推動了 LLM 在長上下文任務中的發展。
研究背景:長上下文處理的挑戰

現代大語言模型(如 Llama、GPT 等)在多輪對話、長文件摘要、影片和視覺資訊理解等任務中需要處理大量歷史資訊,這些任務往往涉及數十萬甚至上百萬個 token 的上下文資訊。例如,處理一篇小說、法律文件或影片轉錄內容,可能需要分析百萬級別的 token。然而,傳統的全注意力機制(Full Attention)要求模型中的每個 token 都要關注序列中的所有前序 token,這導致瞭解碼時間線性增加,預填充(Pre-Filling)時間呈二次增長,同時,KV 快取(Key-Value Cache)的記憶體消耗也隨著上下文長度成線性增長。當上下文達到數百萬 token 時,模型的計算負擔和記憶體消耗將達到難以承受的地步。

DuoAttention 的創新設計

針對這一問題,DuoAttention 框架提出了創新性的 “檢索頭(Retrieval Heads)” 與 “流式頭(Streaming Heads)” 的分離方法。這一設計的核心理念是:並非所有的注意力頭(Attention Heads)在處理長上下文時都需要保留完整的 KV 快取。研究團隊透過大量實驗發現,在長上下文推理任務中,只有一小部分注意力頭,即 “檢索頭”,需要對全部 token 進行關注,以獲取上下文中的關鍵資訊。而大多數注意力頭,即 “流式頭”,只需關注最近的 token 和注意力匯點(Attention Sinks),不需要儲存全部的歷史 KV 狀態。

圖片

圖 1 展示了在 Llama-2-7B 模型上使用全注意力機制的注意力圖(Attention Maps)。從圖中可以看到,檢索頭(Retrieval Heads)捕獲了上下文中如 "best"、"fruit" 和 "orange" 等關鍵資訊,這些資訊對於處理長上下文至關重要,因而需要完整的 KV 快取。而流式頭(Streaming Heads)則主要關注最近的 token 和注意力匯點,不需要保留所有歷史資訊。

DuoAttention 的工作原理
圖片
圖 2 說明了 DuoAttention 的基本工作原理。

框架透過以下幾種關鍵機制來最佳化推理過程:

  • 檢索頭的 KV 快取最佳化:DuoAttention 為檢索頭保留完整的 KV 快取,這些頭對長距離依賴資訊的捕捉至關重要。如果對這些頭的 KV 快取進行剪裁,將導致模型效能嚴重下降。因此,檢索頭需要對上下文中的所有 token 保持 “全注意力(Full Attention)”。
  • 流式頭的輕量化 KV 快取:流式頭則主要關注最近的 token 和注意力匯點。這意味著它們只需要一個固定長度的 KV 快取(Constant-Length KV Cache),從而減少了 KV 快取對記憶體的需求。透過這種方式,DuoAttention 能夠以較低的計算和記憶體代價處理長序列,而不會影響模型的推理能力。

圖片

  • 檢索頭的自動識別:為了準確區分哪些頭是檢索頭,DuoAttention 提出了一種輕量化的最佳化演算法,使用合成資料集來訓練模型自動識別重要的檢索頭。這種最佳化策略透過密碼召回任務(Passkey Retrieval),確定哪些注意力頭在保留或丟棄 KV 快取後對模型輸出有顯著影響。最終,DuoAttention 在推理時根據這一識別結果,為檢索頭和流式頭分別分配不同的 KV 快取策略。

圖片

圖 3 展示了 DuoAttention 使用的合成資料集中的一個樣例。圖 4 展示了 DuoAttention 最終確定 LLM 中各個注意力頭的類別。

效能與準確率實驗

為了驗證 DuoAttention 框架的有效性,研究團隊在多種主流 LLM 架構上進行了廣泛的實驗評估,包括 Llama-2、Llama-3 和 Mistral 模型。實驗不僅測試了 DuoAttention 在記憶體與計算效率上的提升,還透過長上下文和短上下文任務對模型的準確率進行了全面測試。

1.長上下文任務的評估:在 Needle-in-a-Haystack(NIAH)基準測試中,DuoAttention 在極深的上下文條件下表現卓越,保持了高精度,並在處理 1048K 個 token 的長上下文時,依然能夠保持穩定的準確率,而其他方法由於丟失關鍵資訊導致效能下降顯著。在 14 個 LongBench 基準測試中,DuoAttention 展現了在不同任務下的強大泛化能力,能夠以較低的 KV 快取預算,提供接近全注意力機制的準確性。在多頭注意力模型(MHA)上,DuoAttention 使用 25% 的 KV 快取預算即可在多數任務中取得與全快取相當的效果,而在分組查詢注意力模型(GQA)上,50% 的 KV 快取預算即可維持高精度表現。

圖片

2.短上下文任務的評估:在 MMLU(多項選擇題)、MBPP(程式設計能力)和 MT-Bench(幫助能力)等短上下文基準上,DuoAttention 也表現出色。在使用 50% 流式頭的情況下,DuoAttention 的表現幾乎與全注意力機制一致,保持了 LLM 在短文字任務上的原始能力。例如,在 MMLU 基準上,DuoAttention 僅以 0.03% 的差距(79.35% 對比 79.38%)實現了與全注意力機制的相近效能。
圖片
記憶體與效率的提升

  • 記憶體消耗顯著降低:DuoAttention 在多頭注意力模型(Multi-Head Attention,MHA)上將記憶體消耗減少了 2.55 倍,在分組查詢注意力模型(Grouped-Query Attention,GQA)上減少了 1.67 倍。這是由於對流式頭採用了輕量化的 KV 快取策略,使得即使在處理百萬級別的上下文時,模型的記憶體佔用依然保持在較低水平。
  • 解碼(Decoding)和預填充(Pre-Filling)速度提升:DuoAttention 的解碼速度在 MHA 模型中提升了 2.18 倍,在 GQA 模型中提升了 1.50 倍。在預填充方面,MHA 和 GQA 模型的速度分別加快了 1.73 倍 1.63 倍,有效減少了長上下文處理中的預填充時間。

圖片

  • 百萬級 token 處理能力:結合 4 位元量化(Quantization)技術, DuoAttention 實現 Llama-3-8B 在單個 A100 GPU 上處理高達 330 萬 token 的上下文,這一結果是標準全注意力機制的 6.4 倍。
圖片
應用場景與未來展望

DuoAttention 框架為處理長上下文的應用場景帶來了巨大的變革,特別是在需要大規模上下文處理的任務中表現突出,包括:

  • 多輪對話系統(Multi-Turn Dialogues):DuoAttention 使對話模型能夠高效處理長時間對話記錄,從而更好地理解使用者上下文,提升互動體驗。
  • 長文件處理與摘要生成:在文件分析、法律文字處理、書籍摘要等任務中,DuoAttention 極大減少記憶體佔用,同時保持高精度,使長文件處理更加可行。
  • 視覺與影片理解:在涉及大量幀的上下文資訊處理的視覺和影片任務中,DuoAttention 為視覺語言模型(Visual Language Models,VLMs)提供了高效推理方案,顯著提升了處理速度。

研究團隊期望 DuoAttention 框架能夠繼續推動 LLM 在長上下文處理領域的發展,併為更多實際應用場景帶來顯著提升。

相關文章