天下苦英偉達久矣!PyTorch官方免CUDA加速推理,Triton時代要來?

机器之心發表於2024-09-09
近日,PyTorch 官方分享瞭如何實現無 CUDA 計算,對各個核心進行了微基準測試比較,並討論了未來如何進一步改進 Triton 核心以縮小與 CUDA 的差距。

在做大語言模型(LLM)的訓練、微調和推理時,使用英偉達的 GPU 和 CUDA 是常見的做法。在更大的機器學習程式設計與計算範疇,同樣嚴重依賴 CUDA,使用它加速的機器學習模型可以實現更大的效能提升。

雖然 CUDA 在加速計算領域佔據主導地位,併成為英偉達重要的護城河之一。但其他一些工作的出現正在向 CUDA 發起挑戰,比如 OpenAI 推出的 Triton,它在可用性、記憶體開銷、AI 編譯器堆疊構建等方面具有一定的優勢,並持續得到發展。

近日,PyTorch 官宣要做「無英偉達 CUDA 參與的大模型推理」。在談到為什麼要 100% 使用 Triton 進行探索時,PyTorch 表示:「Triton 提供了一條途徑,使大模型 能夠在不同型別的 GPU 上執行,包括英偉達、AMD英特爾和其他基於 GPU 的加速器。

此外 Triton 還在 Python 中為 GPU 程式設計提供了更高的抽象層,使得使用 PyTorch 能夠比使用供應商特定的 API 更快地編寫高效能核心。」

圖片

在 PyTorch 部落格中討論了使用流行的 LLM 模型(例如 Meta 的 Llama3-8B 和 IBM 的 Granite-8B Code)實現 FP16 推理的方法,其中計算是 100% 使用 OpenAI 的 Triton 語言執行的。

對於使用基於 Triton 核心的模型生成單個 token 的時間,PyTorch 能夠實現在英偉達 H100 GPU 上 Llama 和 Granite 的 CUDA 核心主導工作流程的 0.76-0.78 倍效能,以及在英偉達 A100 GPU 上的 0.62-0.82 倍。

圖片

圖 1. 在英偉達 H100 和 A100 上,Llama3-8B 和 Granite-8B 的 Triton 和 CUDA 變體的推理吞吐量比較。設定:批大小 = 2,輸入序列長度 = 512,輸出序列長度 = 256

也許告別英偉達的時候真要來了。

圖片

Transformer 塊的組成

PyTorch 團隊首先對基於 Transformer 的模型中發生的計算進行細分。下圖顯示了典型 Transformer 塊的「核心(kernel)」。

圖片

圖 2

Llama3 架構的核心操作總結如下:

  • 均方根歸一化(RMSNorm)

  • 矩陣乘法:Fused QKV

  • RoPE

  • 注意力

  • 矩陣乘法:輸出投影

  • RMSNorm

  • 矩陣乘法:Fused Gate + Up Projection

  • 啟用函式:SiLU

  • 點乘(Element Wise Multiplication)

  • 矩陣乘法:Down Projection

這些操作中的每一個都是透過在 GPU 上執行一個(或多個)核心來計算的。雖然每個核心的細節在不同的 Transformer 模型中可能有所不同,但核心操作保持不變。例如,IBM 的 Granite 8B Code 模型在 MLP 層中使用偏置,與 Llama3 不同。此類更改確實需要對核心進行修改。典型的模型是這些 Transformer 塊的堆疊,這些 Transformer 塊透過嵌入層連線在一起。

模型推理

典型的模型架構程式碼與 PyTorch 啟動的 python model.py 檔案共享。在預設的 PyTorch Eager Execution 模式下,這些核心都是使用 CUDA 執行的。為了實現 100% Triton 進行端到端 Llama3-8B 和 Granite-8B 推理,需要編寫和整合手寫 Triton 核心以及利用 torch.compile(生成 Triton 操作)。首先,PyTorch 用編譯器生成的 Triton 核心替換較小的操作,其次,PyTorch 用手寫的 Triton 核心替換更昂貴和複雜的計算(例如矩陣乘法和快閃記憶體注意力)。

Torch.compile 自動為 RMSNorm、RoPE、SiLU 和點乘生成 Triton 核心。使用 Nsight Systems 等工具,可以觀察到這些生成的核心,它們在矩陣乘法和注意力之間表現為微小的深綠色核心。

圖片

圖 3. 使用 torch.compile 跟蹤 Llama3-8B,顯示用於矩陣乘法和快閃記憶體注意力的 CUDA 核心。

對於上面的跟蹤,PyTorch 團隊注意到,在 Llama3-8B 樣式模型中,佔 E2E 延遲 80% 的兩個主要操作是矩陣乘法和注意力核心,並且兩者仍然是 CUDA 核心。因此,為了彌補剩餘的差距,PyTorch 團隊用手寫的 Triton 核心替換了 matmul 和注意力核心。

Triton SplitK GEMM 核心

對於線性層中的矩陣乘法,PyTorch 團隊編寫了一個自定義 FP16 Triton GEMM(通用矩陣 - 矩陣乘法)核心,該核心利用了 SplitK 工作分解。

GEMM 核心調優

為了實現最佳效能,PyTorch 團隊使用窮舉搜尋方法來調整 SplitK GEMM 核心。Granite-8B 和 Llama3-8B 具有如下形狀的線性層:

圖片

圖 4. Granite-8B 和 Llama3-8B 線性層權重矩陣形狀。

每個線性層都有不同的權重矩陣形狀。因此,為了獲得最佳效能,必須針對每個形狀輪廓調整 Triton 核心。在對每個線性層進行調整後,PyTorch 能夠在 Llama3-8B 和 Granite-8B 上實現相對於未調整的 Triton 核心 1.20 倍的 E2E 加速。

Flash Attention 核心

PyTorch 團隊使用不同的配置,對現有 Triton flash attention 核心進行了評估,包括

  • AMD Flash

  • OpenAI Flash

  • Dao AI Lab Flash

  • XFormers Flash

  • PyTorch FlexAttention

PyTorch 團隊分別在 eager 模式和編譯模式下評估了每個核心的文字生成質量。下圖 5 為不同 Flash Attention 核心的比較。

圖片

上圖總結了 PyTorch 觀察到的開箱即用情況,並預計核心 2 到 5 可以在修改後滿足上述標準。不過這也表明,擁有一個可用於基準測試的核心通常只是將它用作端到端生產核心的開始。

PyTorch 團隊選擇在後續測試中使用 AMD flash attention 核心,它透過 torch.compile 進行編譯,並在 eager 和編譯模式下產生清晰的輸出。

為了滿足 torch.compile 與 AMD flash attention 核心的相容性,PyTorch 團隊必須將它定義為 torch 自定義運算元。並且封裝更復雜的 flash attention 核心遵循以下兩個步驟:

一是將函式封裝為一個 PyTorch 自定義運算元。

圖片

二是向該運算元新增一個 FakeTensor 核心,並在給定 flash 輸入張量的形狀(q、k 和 v)時,計算 flash 核心的輸出形狀。

圖片

在將 Triton flash 核心定義為一個自定義 op 後,PyTorch 團隊可以成功地對它進行編譯以實現端到端執行。

圖片

圖 6:在交換 Triton matmul 和 Triton flash attention 核心後,使用 torch.compile 的 Llama3-8B 軌跡。

從圖中可以看到,在整合 SplitK 矩陣乘法核心後,torch op 封裝 flash attention 核心,然後執行 torch.compile,即可實現使用 100% Triton 計算核心的前向傳遞。

端到端基準測試

PyTorch 團隊分別對執行 Granite-8B 和 Llama3-8B 模型的英偉達 H100 和 A100(單 GPU)進行了端到端測試,使用了兩種不同的配置來執行基準測試。

其中 Triton 核心配置使用了:

  • Triton SplitK GEMM

  • AMD Triton Flash Attention

CUDA 核心配置使用了

  • cuBLAS GEMM

  • cuDNN Flash Attention - Scaled Dot-Product Attention (SDPA)

在典型推理設定下,兩種 eager 和 torch 編譯模式的吞吐量和 inter-token 延遲如下圖所示。
圖片

圖 7:H100 和 A100 上 Granite-8B 和 Llama3-8B 單 token 生成延遲(批大小 = 2,輸入序列長度 = 512,輸出序列長度 = 256)。

總的來說,在 H100 上,Triton 模型最高可以達到 CUDA 模型效能的 78%;在 A100 上可以達到 82%。這些效能差距是由 matmul 和 flash attention 的核心延遲造成的。

基準測試

下圖 8 為 Triton 和 CUDA 核心延遲比較(英偉達 H100 上執行 Llama3-8B)。輸入為一個任意 prompt(批大小 = 1,prompt 序列長度 = 44),以解碼延遲時間。

最後結果顯示,Triton matmul 核心比 CUDA 慢了 1.2 至 1.4 倍,而 AMD Triton Flash Attention 比 CUDA SDPA 慢了 1.6 倍。

以上結果凸顯了需要進一步提升 GEMM 和 Flash Attention 等核心原語核心的效能。最近的一些工作(如 FlashAttention-3、FlexAttention) 已經提出了更好地利用底層硬體和 Triton 的方法,PyTorch 希望在它們的基礎上實現更大加速。為了闡明這一點,PyTorch 團隊將 FlexAttention 與 SDPA、AMD’s Triton Flash 核心進行了比較。

PyTorch 團隊 正努力驗證 FlexAttention 的端到端效能。目前,FlexAttention 的初始微基準測試結果表明,在查詢向量較小的情況下,有望實現更長的上下文以及解碼問題形狀。

圖片

圖 9:英偉達 H100 SXM5 80GB 上 FlexAttention 核心基準測試(批大小 = 1,最大頭數 = 32,頭維數 = 128)。

未來工作

未來,PyTorch 團隊計劃探索進一步最佳化 matmuls 的方法,以便更好地利用硬體,併為基於 Triton 的方法實現更大的加速。

對於 flash attention,PyTorch 團隊計劃探索 FlexAttention 和 FlashAttention-3 等核心中使用到的技術,以幫助進一步縮小 Triton 與 CUDA 之間的差距。同時還將探索端到端 FP8 LLM 推理。

原文連結:https://pytorch.org/blog/cuda-free-inference-for-llms/

相關文章