優於o1預覽版,推理階段KV快取縮減一半,LightTransfer降本還能增效

机器之心發表於2025-03-10
圖片
LLM 在生成 long CoT 方面展現出驚人的能力,例如 o1 已能生成長度高達 100K tokens 的序列。然而,這也給 KV cache 的儲存帶來了嚴峻挑戰。為應對這一難題,“hybrid model” 成為了一條備受關注的可行路徑:它在標準 transformer 的部分層中引入更高效的注意力機制(如 RNN 或 sliding window attention),以替代原有的注意力層。近期的研究(如 minimax-01、gemma2 等)已經充分驗證了這種混合模型的有效性,但目前依然需要從頭訓練,尚未出現可以直接輕量級遷移已經訓練好的 dense transformer 模型到 hybrid model 的方案。

我們希望提出一種簡潔高效的方法,將已經預訓練完成的 dense transformer 模型順利轉換為 hybrid models。為此,我們提出了 LightTransfer,這一思路源於一個關鍵觀察:現有模型中存在大量呈現 “lazy” 特性的冗餘層 [1]。因此,一個直觀的想法就是將這些冗餘層替換為僅需常數大小 KV cache 的 streaming attention,從而無需維護完整的 KV cache,將 dense Transformer 轉變為更高效的 hybrid model。
圖片
圖片來源:https://arxiv.org/pdf/2309.17453
圖片
  • 專案主頁:https://sites.google.com/view/lighttransfer

  • Huggingface 模型:cxdu/QwQ-32B-LightTransfer

  • github 程式碼:https://github.com/sail-sg/LightTrans

LightTransfer-Train

1) 方法

LightTransfer 的方法非常直接:我們先在訓練集上跑一遍 benchmark,識別出最 “lazy”,也就是 lazy ratio 最高的 50% attention 層,然後將這些層替換為 streaming attention。lazy ratio 用來衡量模型在第 (i) 層的注意力分配:它統計了來自 Query 對初始和最近 key 的注意力權重之和,數值越高就代表該層的注意力越集中在這些 key 上,也就越 lazy。lazy ratio 的具體定義如下:
圖片
其中:
  • 圖片表示最後一部分的查詢(query)集合;

  • 圖片分別表示初始與最近部分的鍵(key)集合;

  • 圖片為在第 i 層從查詢 q 到鍵 k 的注意力權重。

圖片值越高,說明第 i 層對這些鍵的關注度越集中,也就越“lazy”。

QwQ 中每層的 lazy ratio 分佈如下:
圖片
2) 實驗結果

我們的主要實驗物件是 o1 類的長 CoT 生成模型。由於 QwQ 並未公開其訓練資料,我們遵循 STILL [2] 的方案,使用與其完全相同的訓練設定(包括資料集、訓練引數以及以 Qwen2.5-32B-Instruct 作為起點),唯一的差別在於,我們將 50% 的層換成 streaming attention。這樣就能在推理階段顯著縮減近一半的 KV cache。
圖片
從表中可以看出,LightTransfer 在 AIME24 & 25 以及 MathOAI 上的表現優於 QwQ-STILL 和 o1-preview。

LightTransfer-Test

1) Motivation

對於另外一種更為主流的長上下文理解(long context understanding)任務而言,輸入文字本身就非常冗長,因此在測試階段可以對模型進行即時(on-the-fly) 轉換。

2) 方法

基於這一點,我們提出了 LightTransfer-Test,使得模型在推理環節僅依賴 prefilling 的結果就能完成識別和轉換。然而,在實際操作中,我們也面臨了兩個問題:

問題 1:與 Flash Attention 的不相容

當前,Flash Attention 已成為標配,但它並不會顯式計算並儲存注意力權重 (attention weights);因此,如果我們想要獲得用於衡量 lazy ratio 的注意力資訊,就必須重新計算注意力權重,這會帶來不可忽視的額外開銷。

解決方案:為避免重複計算,我們借鑑了 online softmax 的思路,利用 Flash Attention 在計算過程中生成的 LSE(log-sum-exp)作為 lazy ratio 的分母。更值得注意的是,我們驚喜地發現分子的計算複雜度僅為 O (1),而若重新計算則需要 O (seq_len),因此這種方法有效地避免了大規模的重複開銷。具體演算法如下:
圖片
問題 2:prefilling 階段的峰值記憶體

若等到 prefilling 結束後才根據各層的 lazy ratio 進行識別和轉換,那麼整個 prefilling 階段所需的記憶體峰值並沒有減少。

解決方案:為了解決這個問題,我們設計了一種基於優先佇列的策略,保證在 prefilling 階段,所需的記憶體峰值不會超過設定閾值(即 50% 的 full KV + 50% 的 streaming KV)。具體地說,我們維護一個以 lazy ratio 為優先順序的佇列:在 prefilling 過程中,一旦佇列中排隊的層數超出預先設定的閾值(例如 50% 的網路層),我們會從佇列中移除 lazy ratio 最高的層,並將其 KV cache 切換為 streaming 版本。這樣就無需像 SqueezeAttention [3] 那樣等到 prefilling 完成後才壓縮 KV cache,從而有效避免了 prefilling 階段峰值記憶體居高不下的問題。LightTransfer 具體做法如下圖:
圖片
3) 實驗結果
圖片
從表中可以看出,LightTransfer-Test 在 LongBench 上相較於其他層間 KV cache 壓縮方法(如 MiniCache 和 SqueezeAttention)具有更好的表現。它在將近一半的 KV cache 被削減的情況下,四個模型的平均效能僅下降了 1.5%; 尤其是在擁有更多層數的 LlaMa 3-70B 上。

[1] Xiao et al. Efficient streaming language models with attention sinks. ICLR 2024.
[2] Min ei tal. Imitate, explore, and self-improve: A reproduction report on slow-thinking reasoning systems. arXiv 2024.
[3] Wang ei al. Squeezeattention: 2d management of kv-cache in llm inference via layer-wise optimal budget. ICLR 2025.

相關文章