4位元量化三倍加速不掉點!清華即插即用的SageAttention迎來升級

机器之心發表於2024-12-26
圖片
AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

論文共同第一作者張金濤、黃浩峰分別來自清華大學計算機系和交叉資訊研究院,論文通訊作者陳鍵飛副教授及其他合作作者均來自清華大學計算機系。

大模型中,線性層的低位元量化已經逐步落地。然而,對於注意力模組,目前幾乎各個模型都還在用高精度(例如 FP16 或 FP32)的注意力運算進行訓練和推理。並且,隨著大型模型需要處理的序列長度不斷增加,Attention(注意力運算)的時間開銷逐漸成為主要開銷。

此前,清華大學陳鍵飛團隊提出的 8-Bit 的即插即用 Attention(SageAttention),將 Attention 中的 QK^T 量化至 INT8,將 PV 保持為 FP16 精度並使用 FP16 精度的矩陣乘法累加器,同時提出 Smooth K 技術保持了量化 Attention 的精度,實現了 2 倍加速於 FlashAttention2,且在各類大模型上均保持了端到端的精度表現。

目前,SageAttention 已經被業界及社群廣泛地使用於各種開源及商業大模型中,比如 CogvideoX、Mochi、Flux、Llama3、Qwen 等。

近日,陳鍵飛團隊進一步提出了 4-Bit 的即插即用 Attention(SageAttention2),相較於 FlashAttention2 和 xformers 分別實現了 3 倍以及 4.5 倍的即插即用的推理加速,且在影片、影像、文字生成等大模型上均保持了端到端的精度表現
圖片
  • 論文標題:SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization

  • 論文連結:https://arxiv.org/abs/2411.10958

  • 開原始碼:https://github.com/thu-ml/SageAttention

即插即用舉例

SageAttention2 實現了高效的 Attention 運算元,可以實現即插即用的推理加速。輸入任意 Q, K, V 矩陣,SageAttention2 可以快速返回 Attention Output (O)。
圖片
具體來說,SageAttention2 使用起來很方便,克隆倉庫(git clone https://github.com/thu-ml/SageAttention)並執行 python setup.py install 後,只需一行程式碼便可以得到 Attention 的輸出,可以使用該介面方便地替換任意模型中的 Attention 函式:
圖片
效果上,以開源影片生成模型 CogvideoX-1.5-5B 為例,使用 SageAttention2 可以端到端加速 1.8 倍,且生成的影片無損:4位元量化三倍加速不掉點!清華即插即用的SageAttention迎來升級 使用全精度 Attention 4位元量化三倍加速不掉點!清華即插即用的SageAttention迎來升級 使用 SageAttention2

更重要的是,SageAttention2 提供了比 SageAttention 更廣泛的硬體支援。除了在 RTX 4090 上可以 3 倍加速於 FlashAttention 外,在 L20、L40、L40S 可以實現 2 倍的加速,在 A100、A800、A6000 上可以實現 1.45-1.6 倍的加速(基於 SageAttention)。

接下來,研究團隊將從前言、挑戰、方法以及實驗效果四個方面介紹 SageAttention2(總體流程圖如下圖)。
圖片
前言

隨著大模型需要處理的序列長度越來越長,Attention 的速度最佳化變得越來越重要。下圖展示了一個標準的 Transformer 模型中各運算的時間佔比隨序列長度的變化:
圖片
為了方便指代注意力運算中的矩陣,我們先回顧一下注意力的計算公式:
圖片
儘管 SageAttention 提出將 Q,K 量化至 INT8,將 P,V 保持 FP16 精度且採用 FP16 的矩陣乘法累加器來加快 Attention 的速度。然而,這樣做的缺點是:1)INT8 的矩陣乘法只達到了一半的 INT4 矩陣乘法的速度,2)使用 FP16 的乘法累加器的 FP16 的矩陣乘法的加速只在 RTX4090 和 RTX3090 顯示卡上有效。

為了克服上述缺點,SageAttention2 提出將 Q, K 量化至 INT4,並將 P, V 量化至 FP8 來加速 Attention。然而,這樣做的挑戰是很大的。

4-Bit 注意力量化有什麼問題?

研究團隊發現直接將注意力運算中的 Q, K 量化為 INT4 後將會導致在幾乎所有模型和任務上都會得到極差的結果,例如,在 CogVideoX 文生影片模型中,會得到完全模糊的影片;Llama2-7B 進行四選一選擇題任務上得到 25% 的準確率。
圖片
經過仔細分析後,研究團隊發現主要是兩個原因導致了量化注意力的不準確:

(1)INT4 的數值範圍相比 INT8 非常小,導致其量化誤差在 Q,K 矩陣中出現一些異常值時會變得十分明顯,恰好大多模型都在 Q, K 中表現出來了較大的通道維度的異常值。這極大削減了 QK^⊤矩陣乘法的精度。
圖片
(2)研究團隊發現 Nvidia 的顯示卡上,FP8 的矩陣乘法指令 (mma.f32.f8.f8.f32) 的乘法累加器並不是官方宣稱的 FP32 精度,而是隻有 FP22 精度,這導致了 PV 矩陣乘法出現較大的累加誤差。
圖片
技術方案

為了解決上述的兩個挑戰,研究團隊提出了對應的解決辦法。

(1)保留 SageAttention 中對 K 進行平滑處理的同時,提出對 Q 進行平滑處理:Q – mean (Q)。其中 mean (Q) 是沿著通道維度的平均值向量。完成該平滑操作後需要在 Attention 計算過程中將 mean (Q) 和 K^T 的向量與矩陣乘法的結果補償到 S 中。
圖片
這使得相比直接量化 Q, K 至 INT4 的準確度有質的改變,如下表展示了對比了該方法和直接量化 Q, K 至 INT4 在 Cogvideo 和 Llama3.1 上的端到端表現。
圖片
矩陣 Q 平滑前後的資料分佈視覺化的結果如下,可以發現平滑後的 Q 對 INT4 資料範圍的利用度更高:
圖片
(2)對 Q, K 進行 Per-thread 量化。對於矩陣 Q, K,SageAttention2 採用了根據 mma 指令對矩陣記憶體排布的要求,對 Q,K 中的 Token 按照 GPU 執行緒進行分組,使量化粒度比 SageAttention 中的 per-block 細化 16 倍,極大提高了 4Bit 的 QK^⊤乘法準確度的同時不引入任何額外開銷。

具體來說,在 SageAttention 中,每個 Q 的塊將被劃分為 c_w 個段,由 GPU 流處理器(SM)中的 c_w 個 GPU warp 處理。然後,每個包含 32 個執行緒的 warp 會使用 NVIDIA 的 mma.m16n8k64 PTX 指令來執行 QK^⊤運算。根據這一指令的佈局要求,研究團隊發現一個 warp 內的 Q [8×(n%8)] 可以共用一個量化縮放引數,而一個 warp 內的 K [8×(n%8)] 和 K [8×(n%8+1)] 也可以共用一個量化縮放引數,其中 n 是 token 索引。

這種量化方法更為細緻且不增加額外開銷。這是因為它根據 MMA 指令的佈局將不同的 GPU 執行緒分配到不同的量化 Token 組,每個執行緒只對應一個量化縮放引數進行反量化。而非 Per-token 量化那樣,每個執行緒對應多個量化縮放引數。
圖片
如下表所示,可以發現 per-thread 量化的準確度比 SageAttention 中採用的 per-block 量化高得多,準確度和 per-token 量化幾乎沒有差別。
圖片
(3)對 FP8 的 PV 矩陣乘法採用 FP32 的暫存器將每次 FlashAttention 分塊粒度的 PV 的 FP22 的乘法結果累加起來。這種做法可以有效地避免 FP22 的乘法累加器沿著序列長度累積過多的誤差,將 FP22 累加器帶來的誤差控制在 FlashAttention 分塊的粒度中,提高了 FP8 的 PV 乘法的準確度。

(4)針對 P 和 V,研究團隊對比了多種量化的資料型別,對比發現使用 E4M3 資料格式的 FP8 精度最準確,基本接近了 FP16 的準確度。因此採用將 P 和 V 量化至 E4M3。
圖片
下圖展示了 SageAttention2 的演算法流程:
圖片
SageAttention2 共實現了兩種 Kernel,區別在於對 Q, K 進行 INT4 量化還是 INT8 量化:
圖片
此外,SageAttention2 還提出一種可選的對矩陣 V 進行平滑處理的技術,可以進一步提高 PV 矩陣乘法的準確度。具體來說,當某些模型中 V 矩陣具有通道維度的偏移時,可以將 V 減去其通道維度的平均值 mean (V) 來去除偏移,之後進行正常的量化 Attention 運算。只需要對最終 Attention 的 Output 加上 mean (V) 即可保持計算的正確性。
圖片
圖片
這種做法可以提升準確度的原因如下圖所示。在 FP22 的表示範圍內,數值越大,相比 FP32 的誤差越大。而 P 的範圍是 0~1 之間,那麼當 V 矩陣的列有較大的數值偏移時,PV 的 FP22 累加器的精度就越差,透過平滑 V 去除偏移後,就可以加強 PV 矩陣乘法的準確度。
圖片
實驗效果

SageAttention 實現了底層的 GPU CUDA Kernel,在運算元速度以及各個模型端到端準確度上都有十分不錯的表現。

具體來說,運算元速度相比於 FlashAttention2 和 xformers 有大約 3 倍以及 4.5 倍的加速:
圖片
圖片
運算元的準確度方面也是比對 Q, K 進行 SmoothQuant 和 Hadamard 變換要更加準確:
圖片
各模型在真實場景的端到端精度表現中,在影片、影像、文字生成等大模型上均保持了端到端的精度表現:

下圖是在 HunyuanVideo 中的視覺化例項:
圖片
下圖是在 Cogvideo 中的視覺化例項:
圖片
下表展示了各個語言、影片、影像生成模型中 SageAttention2 的端到端精度表現:
圖片
圖片
端到端的速度表現上,SageAttention2 兩個 Kernel 的實現均可以有效地對長序列模型進行加速,比如可以端到端 1.8 倍加速 CogVideoX1.5-5B,其他模型上也均有 1.6 1.8 倍的提速。
圖片

相關文章