FlashAttention逐代解析與公式推導

sasasatori發表於2024-10-18

Standard Attention

標準Attention計算可以簡化為:

\[O = softmax(QK^T)V \tag{1} \]

此處忽略了Attention Mask和維度歸一化因子\(1/\sqrt{d}\)

公式(1)的標準計算方式是分解成三步:

\[S = QK^T \tag{2} \]

\[P=softmax(S) \tag{3} \]

\[O = PV \tag{4} \]

但這樣做的問題在於,假設\(Q,K,V \in R^{N\times d}\),其中\(N\)為序列長度,\(D\)為注意力頭的維度,那麼輸出\(O \in R^{N\times d}\)\(S,P\in R^{N\times N}\)。由於在標準實現下,\(S,P\)都需要從HBM中讀寫,因此構成了\(O(N^2)\)的記憶體複雜度。一般情況下\(N\gg D\),例如GPT-2中,\(N=1024\)\(d=64\),因此\(S\)\(P\)\(O(N^2)\)視訊記憶體開銷是遠大於\(Q,K,V,O\)\(O(Nd)\)的。

一個樸素的想法是:我們能否不進行\(O(N^2)\)的HBM讀寫,透過避免這一頻繁的讀寫操作來大大提升Attention的計算效率。

Online Softmax

假設我們不將\(S,P\)寫回HBM,那麼就得將其放在片上SRAM中。但是這裡的問題是片上SRAM受限於容量,一般無法一次性完整的計算Attention,因此我們必須採用分塊(Tiling)操作,使得分塊後的記憶體需求不超過SRAM的大小。但計算softmax的時候,其歸一化因子(分母)需要所有的輸入資料,因此進行分塊計算的難度較大。

考慮softmax的公式,對於輸入序列\(x\)

\[x=[x_1,x_2,...,x_d] \tag{5} \]

原生softmax函式為:

\[softmax(x_i)=\frac{e^{x_i}}{\sum_{j=1}^{d}e^{x_j}} \tag{6} \]

為了避免數值溢位的問題,現在一般採用safe softmax的方式,即定義:

\[m(x)=max([x_1,x_2,...,x_d]) \tag{7} \]

safe softmax函式在e指數上減去\(m(x)\),使得所有的e指數項的值分佈在0到1之間(因為\(x_i-m(x)\leq 0\)),從而規避數值溢位的問題,此外還能提升數值穩定性,加快計算速度。改造後的函式為:

\[softmax(X)=\frac{e^{x_i-m(x)}}{\sum_{j=1}^{d}e^{x_j-m(x)}} \tag{8} \]

接下來我們需要研究如何對safe softmax應用分塊策略來計算,即所謂的online softmax。

標準的softmax情況下,演算法為:

for i = 1 to N do:

\[m_i\leftarrow max(m_{i-1},x_i) \tag{9} \]

end

for i = 1 to N do:

\[d_i \leftarrow d_{i-1}+e^{x_i-m_N} \tag{10} \]

end

for i = 1 to N do:

\[a_i \leftarrow \frac{e^{x_i-m_N}}{d_N} \tag{11} \]

end

可以看到這個計算過程需要三次1到N的迴圈,在Attention中,這裡的\(x_i\)來自於\(QK^T\),由於我們沒辦法在SRAM中裝下\(Q\)\(K\),因此我們需要從記憶體中訪問他們三次。假如我們能夠想辦法將\((9)\)\((11)\)放到一個迴圈中,我們就能將訪存從三次減少到一次。然而,由於\((9)\)\((10)\)之間存在依賴,因為\((10)\)中包含一個不到最後一次迴圈就無法獲知的\(m_N\),因此我們很難將它們合併起來。

我們可以構造一個\(d_i^{'}=\sum_{j=1}^ie^{x_j-m_i}\)替代原有的\(d_i=\sum_{j=1}^ie^{x_j-m_N}\)以取消其對\(m_N\)的全域性依賴,並且只要一達到\(i=N\),我們自然而然的就有\(d_N^{'}=d_N\),因此我們可以用\(d_N^{'}\)來替換\((11)\)中的\(d_N\)。並且我們可以求得\(d_{i}^{'}\)\(d_{i-1}^{'}\)之間的遞推關係:

\[d_i^{'}=\sum_{j=1}^ie^{x_j-m_i} =(\sum_{j=1}^{i-1}e^{x_j-m_i})+e^{x_i-m_i}=(\sum_{j=1}^{i-1}e^{x_j-m_{i-1}})e^{m_{i-1}-m_i}+e^{x_i-m_i}=d_{i-1}^{'}e^{m_{i-1}-m_i}+e^{x_i-m_i} \tag{12} \]

可以看到這裡的公式依賴於\(m_{i-1}\)\(m_i\)。因此我們可以把\((9)\)\((10)\)放進一個迴圈中:

for i = 1 to N do:

\[m_i\leftarrow max(m_{i-1},x_i) \tag{13} \]

\[d_i^{'} \leftarrow d_{i-1}^{'}e^{m_{i-1}-m_i}+e^{x_i-m_i} \tag{14} \]

end

for i = 1 to N do:

\[a_i \leftarrow \frac{e^{x_i-m_N}}{d_N^{'}} \tag{15} \]

end

這樣我們實現了3次迴圈到2次迴圈的合併,從而減少了1/3的記憶體訪問。但是我們能否進一步直接合併到一步迴圈內呢。對於softmax來說,很不幸的是不可能的。但對於我們要求的Attention來說,這是可以實現的。

FlashAttention V1

對於Attention來說,我們最終要獲得的並非是softmax後得出的矩陣\(P\),而是輸出矩陣\(O=PV\),因此我們的目標是嘗試找到一個一步迴圈求得\(O\)​的方法。

我們先來看應用了online softmax的Attention計算過程:

for i = 1 to N do:

\[x_i \leftarrow Q[k,:]K^T[:,i] \]

\[m_i\leftarrow max(m_{i-1},x_i) \]

\[d_i^{'} \leftarrow d_{i-1}^{'}e^{m_{i-1}-m_i}+e^{x_i-m_i} \]

end

for i = 1 to N do:

\[a_i \leftarrow \frac{e^{x_i-m_N}}{d_N^{'}} \tag{16} \]

\[o_i\leftarrow o_{i-1}+a_iV[i,:] \tag{17} \]

end

\[O[k,:]\leftarrow o_N \]

我們將\((17)\)中的\(a_i\)替換成定義式\((16)\),從而有:

\[o_i=(\sum_{j=1}^i\frac{e^{x_j-m_N}}{d_N^{'}}V[j,:]) \tag{18} \]

這裡可以看到依賴於兩個全域性值\(m_N\)\(d_N^{'}\)。我們可以應用和online softmax推導時類似的技巧,先構造一個\(o_i^{'}\)

\[o_i^{'}=(\sum_{j=1}^i\frac{e^{x_j-m_i}}{d_i^{'}}V[j,:]) \]

只要達到\(i=N\),我們就有\(o_N^{'}=o_N\),並且我們可以求出一個\(o_{i-1}^{'}\)\(o_i^{'}\)之間的遞推公式:

\[o_i^{'}=(\sum_{j=1}^i\frac{e^{x_j-m_i}}{d_i^{'}}V[j,:])=(\sum_{j=1}^{i-1}\frac{e^{x_j-m_i}}{d_i^{'}}V[j,:])+\frac{e^{x_i-m_i}}{d_i^{'}}V[i,:]\\ =(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d_{i-1}^{'}}\frac{e^{x_j-m_i}}{e^{x_j-m_{i-1}}}\frac{d_{i-1}^{'}}{d_i^{'}}V[j,:])+\frac{e^{x_i-m_i}}{d_i^{'}}V[i,:] \\ =(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d_{i-1}^{'}}V[j,:])\frac{d_{i-1}^{'}}{d_i^{'}}e^{m_{i-1}-m_i}+\frac{e^{x_i-m_i}}{d_i^{'}}V[i,:] \\ = o_{i-1}^{'}\frac{d_{i-1}^{'}}{d_i^{'}}e^{m_{i-1}-m_i}+\frac{e^{x_i-m_i}}{d_i^{'}}V[i,:] \tag{19} \]

可以看到這裡不再依賴任何一個全域性值,因此我們可以得到Flash Attention的演算法:

for i = 1 to N do:

\[x_i \leftarrow Q[k,:]K^T[:,i] \]

\[m_i\leftarrow max(m_{i-1},x_i) \]

\[d_i^{'} \leftarrow d_{i-1}^{'}e^{m_{i-1}-m_i}+e^{x_i-m_i} \]

\[o_i^{'}=o_{i-1}^{'}\frac{d_{i-1}^{'}}{d_i^{'}}e^{m_{i-1}-m_i}+\frac{e^{x_i-m_i}}{d_i^{'}}V[i,:] \]

end

\[O[k,:]\leftarrow o_N^{'} \]

我們可以進一步對這個演算法應用分塊(tiling),假定tile的大小為\(b\),共分塊\(\#tiles\)個。那麼\(x_i\)為儲存\([(i-1)b:ib]\)\(QK^T\)值的向量。\(m_i^{(local)}\)為向量\(x_i\)的區域性最大值。那麼對於每個tile,有:

for i = 1 to #tiles do:

\[x_i \leftarrow Q[k,:]K^T[:,(i-1)b:ib] \]

\[m_i^{(local)}\leftarrow max_{j=1}^b(x_i[j]) \]

\[m_i \leftarrow max(m_{i-1},m_i^{(local)}) \]

\[d_i^{'} \leftarrow d_{i-1}^{'}e^{m_{i-1}-m_i}+\sum_{j=1}^b e^{x_i[j]-m_i} \]

\[o_i^{'}=o_{i-1}^{'}\frac{d_{i-1}^{'}}{d_i^{'}}e^{m_{i-1}-m_i}+\sum_{j=1}^{b}\frac{e^{x_i[j]-m_i}}{d_i^{'}}V[j+(i-1)b,:] \]

end

\[O[k,:]\leftarrow o_{N/b}^{'} \]

形象的理解如下圖所示:

image

最後我們來看效果,由於\(S\)\(P\)的計算完全在SRAM上完成(之前做不到的原因在這節開頭時說了,想要完整的把\(S\)\(P\)放上去,片上SRAM的容量不夠,但是採用分塊迭代策略後就ok了)而不需要對HBM做寫回。因此在Standard Attention一節我們分析的,\(O(N^2)\)\(S\)\(P\)的HBM讀寫開銷就沒有了,只有\(Q\)\(K\)\(V\)\(O\)\(O(Nd)\)的開銷,但我們之前也分析過,由於\(N\gg d\),所以\(N^2\gg Nd\),我們可以進一步的當作現在的視訊記憶體開銷變成了只有與\(N\)線性相關,而非二次相關的\(O(N)\)。從\(O(N^2)\)\(O(N)\),這顯然是一個非常顯著的改進。

FlashAttention V2

在V1的基礎上,我們來看V2的一個insight。從硬體的角度來說,GPU計算矩陣乘加的算力是遠高於其他的運算的。具體來說,以A100為例,FP16/BF16的矩陣乘法可以達到312TFLOPS,但是對於非矩陣乘法的FP32,其算力只有19.5TFLOPS,差了一個數量級(16x)。因此一個明顯的改進思路是減少FlashAttention中的非矩陣乘加運算。

觀察公式\((19)\),一個切入點是每個迴圈計算\(O\)時進行了兩次除法,即:

\[o_i^{'}=o_{i-1}^{'}\frac{d_{i-1}^{'}}{d_i^{'}}e^{m_{i-1}-m_i}+\frac{e^{x_i-m_i}}{d_i^{'}}V[i,:] \]

兩項都需要除以\(d_i^{'}\)。因此相當於是進行了2N次的除法。但實際上這個除法操作可以提取到迴圈外,即每次更新\(o_i^{'}\)時,採用:

\[\widetilde{o}_i^{'}=\widetilde{o}_{i-1}^{'}e^{m_{i-1}-m_i}+e^{x_i-m_i}V[i,:] \]

因此每次更新時可以只維護未縮放的\(\widetilde{o}_i^{'}\)。當\(i=N\)時,利用\(o_N^{'}=\widetilde{o}_N^{'}/d_N^{'}\)​,可以將之前每次迴圈中的2N次除法提出,變成迴圈結束後進行一次除法,從而大大減少除法的計算量(從2N次變為1次)。

即:

for i = 1 to N do:

\[x_i \leftarrow Q[k,:]K^T[:,i] \]

\[m_i\leftarrow max(m_{i-1},x_i) \]

\[d_i^{'} \leftarrow d_{i-1}^{'}e^{m_{i-1}-m_i}+e^{x_i-m_i} \]

\[\widetilde{o}_i^{'}=\widetilde{o}_{i-1}^{'}e^{m_{i-1}-m_i}+e^{x_i-m_i}V[i,:] \]

end

\[o_N^{'}=\frac{\widetilde{o}_N^{'}}{d_N^{'}} \]

\[O[k,:]\leftarrow o_N^{'} \]

最本質的原因其實在於在迭代計算時,實際上每一次\(o_i^{'}\)的縮放項\(d_{i-1}^{'}/d_i^{'}\)都可以把上一次\(o_{i-1}^{'}\)的共分母\(d_{i-1}^{'}\)​給吸收掉。因此也可以在迭代時直接丟棄這個冗餘的運算(不妨聯想一下反向傳播的鏈式法則,有一定的相似性)。

V2為了應對訓練時的需求,在前向計算的迴圈中也會暫存維護一個變數,不過我們這裡不做詳細討論。此外V2在演算法上也根據GPU特性更改了內外層迴圈的順序來提高並行度,但這裡就不去做詳細介紹了,可以看論文以及其他的部落格理解。

FlashAttention V3

現在來看V3。在V2的基礎上,為了提升Flash Attention演算法在H100 GPU上的利用率,V3做了幾件事,首先將GEMM操作以Producer & Consumer的形式進行了非同步化,隨後透過Ping-Pong操作將softmax操作隱藏到GEMM操作中(GEMM-softmax流水線),最後應用了更低精度的FP8數制GEMM操作來實現效能提升。

Producer和Consumer的理解其實很簡單,Producer的目的是從HBM中載入計算所需的\(Q\)\(K\)\(V\),而Consumer的內容和V2的公式完全一樣,主要起到消耗掉Producer提供的\(Q\)\(K\)\(V\)並計算\(O\)然後寫回。透過Ping-Pong排程這兩個部分,可以把慢速的softmax操作隱藏到分段的GEMM操作中。具體來說,以下圖為例,當一個Warpgroup在進行GEMM操作時,另一個Warpgroup在進行前一批GEMM操作後的softmax操作中去。

image

更一步的,在一個Warpgroup中,我們可以將一些softmax的指令與GEMM的指令進行並行來進一步提高吞吐率。如下圖所示,可以將一些Softmax的指令隱藏到GEMM的指令執行時間中去。

image

具體的演算法上和V2實際上沒有發生什麼變化。

參考文獻

From Online Softmax to FlashAttention

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

相關文章