稀疏注意力再添一員,華為諾亞推出高效選擇注意力架構ESA
机器之心發表於2025-02-24
AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com
當 DeepSeek 的 NSA 與月之暗面的 MoBA 以稀疏注意力掀起長序列技術熱潮,行業對 “效率革命” 的追逐迎來關鍵一躍 —— 華為諾亞方舟實驗室正式釋出全新 ESA 演算法(Efficient Selective Attention)。論文地址:https://arxiv.org/pdf/2502.14477透過稀疏化注意力的創新設計,ESA 突破了大模型在長文字處理中的瓶頸。ESA 不僅實現了數倍序列長度的擴充,還引入獨創的動態計算正規化,結合鄰域影響力有效避免了單純選擇 top-ranked token 所帶來的效能損失。透過對關鍵 token 的精確選擇,ESA 在最佳化長序列處理效率的同時,提升了計算效能,為大模型在長序列任務中的應用帶來了新的可能性。在大語言模型的推理過程中,長序列模型的訓練需要極高的算力和海量資料支援,理想的解決方案是透過短序列的訓練成果外推到長序列。然而,隨著序列長度的增加,注意力計算的複雜度呈平方級增長,這使得高效且準確的長序列推理成為了一大挑戰。為此,研究人員提出了多種方法,以應對這一挑戰。ESA 方案正是在這一背景下提出的創新外推解決方案。ESA 透過對 query 和 key 的低維壓縮,有效減少了 token 選擇的計算複雜度。該方案透過靈活高效地選擇關鍵 token 進行注意力計算,大幅度降低了 LLMs 在處理長文字時的計算負擔,且在效能上與全注意力外推方法相當,甚至在高倍外推場景下優於全注意力演算法,實現了上下文長度的有效擴充套件。當大模型訓練長度有限,隨著序列長度的增長,一方面會出現 OOD (out-of-distribution) 的問題,另一方面注意力計算量會迅速增大。現有的研究表明,注意力矩陣具有稀疏性,對於長序列而言,稀疏程度進一步擴大。選擇性注意力(Selective Attention)利用了稀疏性這一特性,選擇部分 token 來計算注意力,結合外推的位置編碼能將短序列模型應用到長序列任務上的同時,顯著降低計算量。在計算稀疏注意力時細粒度的 token 選擇方法能夠更加靈活、精準地定位到關鍵資訊。然而,token 粒度選擇會引入巨大的計算開銷。這引出了一個核心的問題:如何在選擇性注意力方法中平衡靈活性與效率。針對這一挑戰,ESA 方法透過將 query 和 key 進行低維壓縮,顯著降低 token 選擇的計算複雜度,在外推場景下實現 token 粒度動態稀疏注意力機制。高效選擇:ESA 引入了一種基於 query 感知的 token 粒度選擇機制,基於壓縮後的 query 和 key 計算 token 的重要性分數,同時考慮周圍 token 的影響(鄰距影響力),以避免直接選擇 top-ranked token 導致的效能下降。注意力計算:在選擇關鍵 token 後,ESA 使用被選中的 token 的完整的 query 和 key 進行注意力計算,而非對所有前序 token 進行計算,從而大幅降低複雜度。2.ESA:基於 token 粒度的高效選擇性注意力ESA 的主要創新點在於透過 token 粒度選擇性注意力機制,在保持模型準確率的同時顯著降低計算複雜度。具體來說,與現有的長序列外推方法不同,ESA 提出了一種基於 token 的細粒度選擇注意力,能夠在 prefilling 和 decoding 階段動態選擇最關鍵的少量 token,而不是固定 block 選擇或者永久丟棄不重要的 token。首先,ESA 將 query 和 key 經過簡單的一層 MLP 壓縮到原有維度的大約 3.2%,在低維空間計算重要性分數,顯著降低計算複雜度;其次,根據重要性分數選擇 topk 的 token,控制 key 的長度是固定的,這樣將注意力計算由原有的平方複雜度降低為線性複雜度。雖然選擇 token 是平方複雜度,但是由於將 query 和 key 壓縮到了更低維的空間,使得對於算力要求大大降低。ESA 的具體實現方式如下:輸入序列的 token 被分為 4 部分,注意力包括全域性注意力和 window 的區域性注意力,初始 token 和 ESA 選擇的 topk 中間 token 拼接起來計算全域性注意力,localtoken 用於計算 window 的注意力,兩部分注意力進行融合計算最終的注意力。ESA 按照 chunked-prefill 快取 key 和 value,即基於當前 chunk 的 query 選擇重要的中間 tokens,計算 token 的重要性時兼顧當前的所有 query;在解碼階段,只需要考慮當前的一個 token 的 query 即可。如果計算中間某個 token 重要性,需要計算和當前所有 token 的重要性,其中單個 token 的重要性用 query 和 key 的點積表示:這裡 H 是 head 的數量,為了降低複雜度 ESA 整合了所有的 head。為了進一步降低計算複雜度,不要求準確計算重要性分數,而是更關注相對大小,ESA 將 query 和 key 分別透過一層 MLP 進行壓縮。ESA 採取 offline 的方式學習 MLP 的權重:ESA 使用一個小的校準資料集用模型進行推理,儲存中間的 query、key 和 value,用於訓練降維 MLP,只增加了極少量的降低 query 和 key 大小的網路權重,且無需對模型微調。為了確保分數的相對大小,避免某個 token 在重要性分數中佔據主導地位,ESA 對分數進行修正:進一步的,作者發現僅選擇 topk 的 token 模型在大海撈針任務中只能檢索到部分資訊,提出了鄰距影響力的概念,即對於某個中間的 token,其重要性分數不僅取決於自身的分數,還受到周圍 token 的影響,更新後的分數為:在選擇完重要 token 後,ESA 使用完整的 query、key 和 value 計算注意力,最終的注意力輸出如下所示:ESA 的計算複雜度降低主要來源於低維的 query 和 key 計算重要性分數以及選擇完成以後的線性注意力計算複雜度,經過理論計算,一步 attention 計算在長序列場景下能降低為原有的:實際實驗中我們將 query 和 key 壓縮為原有的 3.2%,一步 attention 計算量在輸入序列足夠長時理論能降低至 1.6% 左右。論文選擇開源訓練集 Pile 的 2 條 Books3 樣本收集用於訓練降維 MLP 的 qk 樣本,query 和 key 從 4096 壓縮為 128,壓縮比例約為 l3.2%,注意力計算的視窗長度約為 6k。為了將開源的短序列模型應用到長序列中,ESA 沿用了 Infllm 的外推位置編碼設定,使用 Llama-3-8B-Instruct 和 Mistral-7B-Instruct-v0.2,在多個公開的長序列基準測試中驗證了 ESA 的效能,包括 Longbench、InfiniteBench、NeedleBench 等。作者對比了 full attention 的外推方法和同型別的基於 window 的外推方法,且同型別方法的 window 長度一致。實驗結果表明,ESA 透過高效靈活選擇重要的 token,總體效能在外推倍數足夠大時候優於 full attention 的方法,且均明顯優於同型別的方法,尤其在 multi needles 檢索場景下例如數星星和 NeedleBench,在其他同型別方法失效的時候,ESA 仍然有較高的準確率。ESA 不對每個 head 單獨選擇 token,而是將所有 head 整合到一起計算重要性分數,有利於降低計算複雜度,提升效率,為了驗證這一操作對演算法的影響,作者做的對比實驗如下所示,可以看出這樣的整合對於演算法影響有限。論文研究了鄰距影響力的超引數影響,結果如下所示,對不同的測評集該引數的影響不同,取值較小有利於 multi needles 型別的檢索任務,取值較大則有利於 single needle 型別任務,這可能是由於單針檢索任務只需要關注 ground truth 所在的片段即可,增大鄰距影響力有利於 attention 集中到較長的片段上。ESA 有效平衡了長序列外推場景下的選擇性注意力中的靈活性和計算效率,用於在不進行模型引數增量微調的情況下擴充套件上下文長度。ESA 的核心思想是在每個步驟中選擇固定數量的最重要 token 來計算注意力,利用注意力矩陣的稀疏性。當輸入序列足夠長時,ESA 透過將 query 和 key 壓縮為低維表徵,有效降低選擇 token 的計算複雜度。實驗評估表明,ESA 能夠有效處理長度為訓練長度 4 倍甚至 25 倍的各種長序列任務。未來的研究需要探索更準確、更高效的選擇重要 token 的方法,以及軟硬體協同的高效外推方案。