新PyTorch API:幾行程式碼實現不同注意力變體,兼具FlashAttention效能和PyTorch靈活性

机器之心發表於2024-08-11

用 FlexAttention 嘗試一種新的注意力模式。


理論上,注意力機制就是你所需要的一切。然而在實際操作中,我們還需要最佳化像 FlashAttention 這樣的注意力機制的實現。

儘管這些融合的注意力機制大大提高了效能,且支援長上下文,但這種效率的提升也伴隨著靈活性的喪失。對於機器學習研究人員來說,這就像是一種「軟體彩票」—— 如果你的注意力變體不適合現有的最佳化核心,你將面臨執行緩慢和 CUDA 記憶體不足的困境。

一些注意力變體包括因果注意力、相對位置嵌入、Alibi、滑動視窗注意力、PrefixLM、文件掩碼、不規則張量、PagedAttention 等。更糟糕的是,人們通常希望將這些變體組合在一起!比如滑動視窗注意力 + 文件掩碼 + 因果注意力 + 上下文並行,又比如 PagedAttention + 滑動視窗的組合。

下圖左側代表了當今的現狀 —— 一些掩碼 + 偏置 + 設定的組合已經有現成的核心實現。然而,各種選項的新增會導致設定呈指數級增長。更糟糕的是,這種方式不會支援新的注意力變體。
圖片
為了徹底地解決這個超立方體問題,PyTorch 團隊引入了 FlexAttention,一個新的 PyTorch API。

  1. FlexAttention 是一個靈活的 API,允許使用者使用幾行慣用的 PyTorch 程式碼就能實現多個注意力變體。
  2. 團隊人員透過 torch.compile 將其降低到一個融合的 FlashAttention 核心中 ,生成了一個不會佔用額外記憶體且效能可與手寫核心相媲美的 FlashAttention 核心。
  3. 利用 PyTorch 的自動求導機制自動生成反向傳播。
  4. 最後,PyTorch 團隊還可以利用注意力掩碼中的稀疏性,從而顯著改善標準注意力實現。
圖片
FlashAttention 1-3 版本的參與者 Tri Dao 對這項研究進行了轉發並評論:這項研究使得很多技術都融合在一起了。
圖片
FlexAttention

經典的注意力方程式如下:
圖片
程式碼形式:
圖片
FlexAttention 形式如下,其透過接受使用者定義的函式 score_mod 來解決上述問題。
圖片程式碼形式:
圖片
此函式允許使用者在 softmax 之前修改注意力分數。研究人員發現,該函式最終足以滿足大多數使用者對注意力變體的需求。

具體而言,score_mod 如下:
圖片
要應用此函式,可以將其實現為:
for b in range (batch_size):
    for h in range (num_heads):
        for q_idx in range (sequence_length):
            for kv_idx in range (sequence_length):
                modified_scores [b, h, q_idx, kv_idx] = score_mod (scores [b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)

最終的 API 具有令人驚訝的表達能力。

Score Mod 示例

全注意力

在這種情況下,score_mod 無操作,它接受分數作為輸入,然後原樣返回它們。
圖片
然後端到端的使用。
圖片
相對位置編碼

一種常見的注意力變體是相對位置編碼。相對位置編碼不是對查詢和鍵中的絕對距離進行編碼,而是根據查詢和鍵之間的距離調整分數。
圖片
需要注意的是,與典型實現不同,這不需要具體化 SxS 張量。相反,FlexAttention 會在核心中動態計算偏差值,從而顯著提高記憶體和效能。
圖片
Soft-capping

Soft-capping 是 Gemma 2 和 Grok-1 使用的一種技術,在 FlexAttention 中,它的形式是這樣的:
圖片
Causal Mask

儘管雙向注意力很簡單,但在論文《Attention is All You Need》,以及其他的 LLM 中,它們的設定都是僅解碼器的注意力,其中每個 token 只能關注它之前的 token。如果使用者使用 score_mod API ,可以將其表示為:
圖片
Sliding Window + Causal
圖片
圖源:https://arxiv.org/abs/2310.06825

Mistral 一直在推廣滑動視窗注意力(也稱為區域性注意力),它允許查詢 token 僅關注最近的 1024 個 token,通常與因果注意力一起使用。
圖片
研究者對帶有滑動視窗掩碼的 F.scaled_dot_product_attention 以及帶有因果掩碼的 FA2 進行基準測試。結果表明,FlexAttention 不僅明顯快於 F.scaled_dot_product_attention,也明顯快於帶有因果掩碼的 FA2。
圖片
效能

總體而言,FlexAttention 的效能幾乎與手寫的 Triton 核心一樣好。然而,由於 FlexAttention 具有通用性,因此會遭受輕微的效能損失。例如,使用者必須承受一些額外的延遲。

FlexAttention 在前向傳播中實現了 FlashAttention2 效能的 90%,在反向傳播中實現了 85%。FlexAttention 目前正在使用一種確定性演算法,該演算法比 FAv2 重新計算了更多的中間體,研究者計劃改進 FlexAttention 的反向演算法,來縮小這一差距!
圖片
圖片
參考連結:https://pytorch.org/blog/flexattention/

相關文章