英偉達又賺到了!FlashAttention3來了:H100利用率飆升至75%

机器之心發表於2024-07-12

740 TFLOPS!迄今最強 FlashAttention 來了。

隨著大型語言模型(LLM)加速落地,擴充套件模型上下文視窗變得越來越重要。然而,Transformer 架構的核心 —— 注意力層的時間複雜度和空間複雜度與輸入序列長度的平方成正比。這使得擴充套件模型上下文視窗存在挑戰。

2022 年,一種快速、記憶體高效的注意力演算法 ——FlashAttention 問世,該演算法無需任何近似即可加速注意力並減少記憶體佔用。

FlashAttention 對注意力計算進行重新排序的演算法,並利用 tiling 和重計算來顯著加快計算速度,將記憶體使用量從序列長度的二次減少到線性。

圖片

2023 年,研究團隊宣佈推出 FlashAttention-2,在演算法、並行化和工作分割槽等方面有了顯著改進。

現在,來自 Meta、英偉達、Together AI 等機構的研究者宣佈推出 FlashAttention-3,它採用了加速 Hopper GPU 注意力的三種主要技術:

  • 透過 warp-specialization 重疊整體計算和資料移動;

  • 交錯分塊 matmul 和 softmax 運算;

  • 利用硬體支援 FP8 低精度的不連貫處理。

FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍,高達 740 TFLOPS,即 H100 理論最大 FLOPS 利用率為 75%。使用 FP8,FlashAttention-3 的速度更是接近 1.2 PFLOPS。

FlashAttention-3 的改進將帶來:

  • 更高效的 GPU 利用率:H100 理論最大 FLOPS 利用率為 75%,而之前僅為 35%。這使得 LLM 的訓練和執行速度比以前的版本快得多。

  • 較低精度下更好的效能:FlashAttention-3 可以在保持精度的同時使用較低精度的數字 (FP8)。這可以實現更快的處理速度並可能降低記憶體使用量,從而為執行大規模人工智慧操作的客戶節省成本並提高效率。

  • 能夠在 LLM 中使用更長的上下文:透過加速注意力機制,FlashAttention-3 使 AI 模型能夠更有效地處理更長的文字片段。這使得應用程式能夠理解並生成更長、更復雜的內容而不會減慢速度。

圖片

論文標題:FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

論文地址:https://tridao.me/publications/flash3/flash3.pdf

論文作者之一 、FlashAttention1-3 版本的參與者 Tri Dao 表示:FlashAttention 被廣泛用於加速 Transformers,已經使注意力速度提高了 4-8 倍,但尚未利用現代 GPU。因而他們釋出了 FlashAttention-3:在 FP16 上速度提高了 1.5-2 倍,在 H100 上高達 740 TFLOPS(75% 實用性),FP8 接近 1.2 PFLOPS!

圖片

Hopper GPU 硬體特性:WGMMA、TMA、FP8

雖然 FlashAttention-2 在 Ampere (A100) GPU 上可以實現 70% 的理論最大 FLOPS,但它尚未利用 Hopper GPU 上的新功能來最大限度地提高效能。接下來文章描述了一些新的 Hopper 特定功能,以及它們為何如此重要。

首先是 WGMMA(Warpgroup Matrix Multiply-Accumulate),該功能利用了 Hopper 架構上新的張量核心,比 Ampere 架構具有更高的吞吐量。

圖片

然後是 TMA(Tensor Memory Accelerator),這是一個特殊的硬體單元,可以加速全域性記憶體和共享記憶體之間的資料傳輸,用於處理所有索引計算和邊界外預測。這樣一來暫存器就釋放了,暫存器是增加 tile 大小和效率的寶貴資源。

圖片

低精度 FP8,讓 Tensor Core 吞吐量翻了一倍。

圖片

FlashAttention-3 充分利用了 Hopper 架構的所有這些新功能。

非同步:GEMM 和 Softmax 重疊

注意力機制主要有兩個操作,GEMM 和 softmax。為什麼要將它們重疊?

問題在於在現代加速器上,非矩陣乘法(matmul)運算比矩陣乘法運算慢。特殊函式如指數運算(如 softmax 函式)的吞吐量甚至低於浮點乘加操作;這些運算是由多功能單元處理的,這是一個與浮點乘加或矩陣乘加不同的單元。

理想情況下,研究者希望矩陣乘法和 softmax 能夠並行操作。當 Tensor Cores 忙於矩陣乘法時,多功能單元應當在計算指數運算!

Inter-warpgroup 重疊

重疊 GEMM 和 softmax 最簡單的方法是什麼都不做,warp 排程程式會免費完成部分重疊。下圖說明了 pingpong 排程,其中相同的顏色表示相同的迭代。

圖片

Intra-warpgroup 重疊

即使在一個 warpgroup 中,研究者也可以在執行該 warpgroup 的 GEMM 時執行 softmax 的某些部分。如圖所示,相同的顏色表示相同的迭代。

圖片

這種 pipeline 流程可以將 FP16 注意力前向傳播的吞吐量從大約 620 TFLOPS 提高到 640-660 TFLOPS,但代價是更高的暫存器壓力,因而需要更多的暫存器來同時儲存 GEMM 的累加器以及 Softmax 的輸入 / 輸出。

低精度:使用非相干處理減少量化誤差

啟用 LLM 可能存在一些極端值,導致量化困難,從而產生較大的量化誤差。本文采用非相干處理(incoherent processing),該技術透過將查詢和鍵與一個隨機正交矩陣相乘來「分散(spread out)」極端值,從而減少量化誤差。特別地,該研究使用了 Hadamard 變換,它可以在每個注意力頭中以 O (d log d) 的時間複雜度完成,而不是 O (d^2),其中 d 是頭部維度。

研究者發現非相干處理可以將量化誤差減少很多,具體的數值誤差比較見下表。

圖片

實驗

文中展示了 FlashAttention-3 的一些結果,並將其與 FlashAttention-2 以及 Triton 和 cuDNN 中的實現進行了比較(兩者都已經使用了 Hopper GPU 的新硬體功能)。

在 FP16 精度下,FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍。

圖片

對於 FP8,FlashAttention-3 接近 1.2 PFLOPS。

圖片

擴充套件閱讀:

史丹佛提出新型Attention演算法!提速2-4倍,BERT單節點訓練最快

比標準Attention提速5-9倍,大模型都在用的FlashAttention v2來了

參考連結:

https://tridao.me/blog/2024/flash3/

相關文章