Standard Attention
標準Attention計算可以簡化為:
此處忽略了Attention Mask和維度歸一化因子\(1/\sqrt{d}\)。
公式(1)的標準計算方式是分解成三步:
但這樣做的問題在於,假設\(Q,K,V \in R^{N\times d}\),其中\(N\)為序列長度,\(D\)為注意力頭的維度,那麼輸出\(O \in R^{N\times d}\),\(S,P\in R^{N\times N}\)。由於在標準實現下,\(S,P\)都需要從HBM中讀寫,因此構成了\(O(N^2)\)的記憶體複雜度。一般情況下\(N\gg D\),例如GPT-2中,\(N=1024\),\(d=64\),因此\(S\)和\(P\)的\(O(N^2)\)視訊記憶體開銷是遠大於\(Q,K,V,O\)的\(O(Nd)\)的。
一個樸素的想法是:我們能否不進行\(O(N^2)\)的HBM讀寫,透過避免這一頻繁的讀寫操作來大大提升Attention的計算效率。
Online Softmax
假設我們不將\(S,P\)寫回HBM,那麼就得將其放在片上SRAM中。但是這裡的問題是片上SRAM受限於容量,一般無法一次性完整的計算Attention,因此我們必須採用分塊(Tiling)操作,使得分塊後的記憶體需求不超過SRAM的大小。但計算softmax的時候,其歸一化因子(分母)需要所有的輸入資料,因此進行分塊計算的難度較大。
考慮softmax的公式,對於輸入序列\(x\):
原生softmax函式為:
為了避免數值溢位的問題,現在一般採用safe softmax的方式,即定義:
safe softmax函式在e指數上減去\(m(x)\),使得所有的e指數項的值分佈在0到1之間(因為\(x_i-m(x)\leq 0\)),從而規避數值溢位的問題,此外還能提升數值穩定性,加快計算速度。改造後的函式為:
接下來我們需要研究如何對safe softmax應用分塊策略來計算,即所謂的online softmax。
標準的softmax情況下,演算法為:
for i = 1 to N do:
end
for i = 1 to N do:
end
for i = 1 to N do:
end
可以看到這個計算過程需要三次1到N的迴圈,在Attention中,這裡的\(x_i\)來自於\(QK^T\),由於我們沒辦法在SRAM中裝下\(Q\)和\(K\),因此我們需要從記憶體中訪問他們三次。假如我們能夠想辦法將\((9)\)到\((11)\)放到一個迴圈中,我們就能將訪存從三次減少到一次。然而,由於\((9)\)和\((10)\)之間存在依賴,因為\((10)\)中包含一個不到最後一次迴圈就無法獲知的\(m_N\),因此我們很難將它們合併起來。
我們可以構造一個\(d_i^{'}=\sum_{j=1}^ie^{x_j-m_i}\)替代原有的\(d_i=\sum_{j=1}^ie^{x_j-m_N}\)以取消其對\(m_N\)的全域性依賴,並且只要一達到\(i=N\),我們自然而然的就有\(d_N^{'}=d_N\),因此我們可以用\(d_N^{'}\)來替換\((11)\)中的\(d_N\)。並且我們可以求得\(d_{i}^{'}\)和\(d_{i-1}^{'}\)之間的遞推關係:
可以看到這裡的公式依賴於\(m_{i-1}\)和\(m_i\)。因此我們可以把\((9)\)和\((10)\)放進一個迴圈中:
for i = 1 to N do:
end
for i = 1 to N do:
end
這樣我們實現了3次迴圈到2次迴圈的合併,從而減少了1/3的記憶體訪問。但是我們能否進一步直接合併到一步迴圈內呢。對於softmax來說,很不幸的是不可能的。但對於我們要求的Attention來說,這是可以實現的。
FlashAttention V1
對於Attention來說,我們最終要獲得的並非是softmax後得出的矩陣\(P\),而是輸出矩陣\(O=PV\),因此我們的目標是嘗試找到一個一步迴圈求得\(O\)的方法。
我們先來看應用了online softmax的Attention計算過程:
for i = 1 to N do:
end
for i = 1 to N do:
end
我們將\((17)\)中的\(a_i\)替換成定義式\((16)\),從而有:
這裡可以看到依賴於兩個全域性值\(m_N\)和\(d_N^{'}\)。我們可以應用和online softmax推導時類似的技巧,先構造一個\(o_i^{'}\):
只要達到\(i=N\),我們就有\(o_N^{'}=o_N\),並且我們可以求出一個\(o_{i-1}^{'}\)到\(o_i^{'}\)之間的遞推公式:
可以看到這裡不再依賴任何一個全域性值,因此我們可以得到Flash Attention的演算法:
for i = 1 to N do:
end
我們可以進一步對這個演算法應用分塊(tiling),假定tile的大小為\(b\),共分塊\(\#tiles\)個。那麼\(x_i\)為儲存\([(i-1)b:ib]\)的\(QK^T\)值的向量。\(m_i^{(local)}\)為向量\(x_i\)的區域性最大值。那麼對於每個tile,有:
for i = 1 to #tiles do:
end
形象的理解如下圖所示:
最後我們來看效果,由於\(S\)和\(P\)的計算完全在SRAM上完成(之前做不到的原因在這節開頭時說了,想要完整的把\(S\),\(P\)放上去,片上SRAM的容量不夠,但是採用分塊迭代策略後就ok了)而不需要對HBM做寫回。因此在Standard Attention一節我們分析的,\(O(N^2)\)的\(S\),\(P\)的HBM讀寫開銷就沒有了,只有\(Q\),\(K\),\(V\),\(O\)的\(O(Nd)\)的開銷,但我們之前也分析過,由於\(N\gg d\),所以\(N^2\gg Nd\),我們可以進一步的當作現在的視訊記憶體開銷變成了只有與\(N\)線性相關,而非二次相關的\(O(N)\)。從\(O(N^2)\)到\(O(N)\),這顯然是一個非常顯著的改進。
FlashAttention V2
在V1的基礎上,我們來看V2的一個insight。從硬體的角度來說,GPU計算矩陣乘加的算力是遠高於其他的運算的。具體來說,以A100為例,FP16/BF16的矩陣乘法可以達到312TFLOPS,但是對於非矩陣乘法的FP32,其算力只有19.5TFLOPS,差了一個數量級(16x)。因此一個明顯的改進思路是減少FlashAttention中的非矩陣乘加運算。
觀察公式\((19)\),一個切入點是每個迴圈計算\(O\)時進行了兩次除法,即:
兩項都需要除以\(d_i^{'}\)。因此相當於是進行了2N次的除法。但實際上這個除法操作可以提取到迴圈外,即每次更新\(o_i^{'}\)時,採用:
因此每次更新時可以只維護未縮放的\(\widetilde{o}_i^{'}\)。當\(i=N\)時,利用\(o_N^{'}=\widetilde{o}_N^{'}/d_N^{'}\),可以將之前每次迴圈中的2N次除法提出,變成迴圈結束後進行一次除法,從而大大減少除法的計算量(從2N次變為1次)。
即:
for i = 1 to N do:
end
最本質的原因其實在於在迭代計算時,實際上每一次\(o_i^{'}\)的縮放項\(d_{i-1}^{'}/d_i^{'}\)都可以把上一次\(o_{i-1}^{'}\)的共分母\(d_{i-1}^{'}\)給吸收掉。因此也可以在迭代時直接丟棄這個冗餘的運算(不妨聯想一下反向傳播的鏈式法則,有一定的相似性)。
V2為了應對訓練時的需求,在前向計算的迴圈中也會暫存維護一個變數,不過我們這裡不做詳細討論。此外V2在演算法上也根據GPU特性更改了內外層迴圈的順序來提高並行度,但這裡就不去做詳細介紹了,可以看論文以及其他的部落格理解。
FlashAttention V3
現在來看V3。在V2的基礎上,為了提升Flash Attention演算法在H100 GPU上的利用率,V3做了幾件事,首先將GEMM操作以Producer & Consumer的形式進行了非同步化,隨後透過Ping-Pong操作將softmax操作隱藏到GEMM操作中(GEMM-softmax流水線),最後應用了更低精度的FP8數制GEMM操作來實現效能提升。
Producer和Consumer的理解其實很簡單,Producer的目的是從HBM中載入計算所需的\(Q\),\(K\),\(V\),而Consumer的內容和V2的公式完全一樣,主要起到消耗掉Producer提供的\(Q\),\(K\),\(V\)並計算\(O\)然後寫回。透過Ping-Pong排程這兩個部分,可以把慢速的softmax操作隱藏到分段的GEMM操作中。具體來說,以下圖為例,當一個Warpgroup在進行GEMM操作時,另一個Warpgroup在進行前一批GEMM操作後的softmax操作中去。
更一步的,在一個Warpgroup中,我們可以將一些softmax的指令與GEMM的指令進行並行來進一步提高吞吐率。如下圖所示,可以將一些Softmax的指令隱藏到GEMM的指令執行時間中去。
具體的演算法上和V2實際上沒有發生什麼變化。
參考文獻
From Online Softmax to FlashAttention
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision