AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com
論文第一作者張金濤來自清華大學計算機系,論文通訊作者陳鍵飛副教授及其他合作作者均來自清華大學計算機系。
大模型中,線性層的低位元量化(例如 INT8, INT4)已經逐步落地;對於注意力模組,目前幾乎各個模型都還在用高精度(例如 FP16 或 FP32)的注意力運算進行訓練和推理。然而,隨著大型模型需要處理的序列長度不斷增加,Attention(注意力運算)的時間開銷逐漸成為網路最佳化的主要瓶頸。
為了提高注意力運算的效率,清華大學陳鍵飛團隊提出了 8Bit 的 Attention(SageAttention)。實現了 2 倍以及 2.7 倍相比於 FlashAttention2 和 xformers 的即插即用的推理加速,且在影片、影像、文字生成等大模型上均沒有端到端的精度損失。
論文標題:SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration
論文連結:https://arxiv.org/abs/2410.02367
開原始碼:https://github.com/thu-ml/SageAttention
大多影片、影像生成模型中,矩陣 K 表現出了極強的通道維度的異常值分佈,直接使用 INT8 或者 FP8 資料型別對其進行量化會導致巨大的誤差。
在所有模型中,對矩陣 P, V 進行量化不能保證一個模型中所有層的精度。下表展示了對 P, V 量化後,Llama2-7B 和 Unidiffuser 模型所有層中,最差情況的層對應的量化注意力的準確度,(該準確度為量化注意力相比全精度注意力的誤差),可以發現不管對 P, V 矩陣進行何種 8Bit (INT8,E4M3,E5M2)量化,總有些層的準確率非常差,導致了端到端效果的下降。
對 K 進行平滑處理。SageAttention 採用了一個簡單但非常實用的方法來消除矩陣 K 的異常值:K = K – mean (K) 其中 mean (K) 是沿著通道維度求平均值。這個簡單的做法不僅不會影響注意力計算的正確性 Softmax (QK^T) = Softmax (Q (K-mean (K))^T) ;且對整個 Attention 速度的影響只有 0.2%;同時還保證了量化後的注意力運算的精度:
對 Q, K 進行分塊 INT8 量化。對於矩陣 Q, K,SageAttention 採用了以 FlashAttention 的分塊大小為粒度的 INT8 量化。這是因為:1. 對 Q, K 矩陣進行 INT8 量化相比於進行 FP8 量化,注意力的精度更高。2. 在一些常用卡上,比如 RTX4090,INT8 矩陣乘法(INT32 為累加器)的速度是 FP8(FP32 為累加器)的兩倍。
對 P, V 採用 FP16 資料型別的矩陣乘法累加器。對於矩陣 P, V,SageAttention 採用了保留 P, V 為 FP16 的型別,但進行矩陣乘法時採用 FP16 資料型別的累加器。這是因為:1. PV 矩陣乘法的數值範圍始終在 FP16 的表示範圍內,且經過大量實驗驗證,FP16 作為累加器的資料型別不會帶來任何精度損失(見下表)。2. 在一些常用卡上,比如 RTX4090,以 FP16 為累加器資料型別的矩陣乘法的速度是 FP32 作為累加器的兩倍。