Java中神經網路Triton GPU程式設計

banq發表於2024-03-10


在本文中,我們將介紹如何使用程式碼反射在 Java 中實現 Triton 程式設計模型,以替代 Python。

程式碼反射(Code Reflection)是 OpenJDK Project Babylon 專案正在研究和開發的一項 Java 平臺功能。

什麼是Triton
Triton 是一種特定領域程式設計模型和編譯器,開發人員可以用它來編寫可編譯為 GPU 程式碼的 Python 程式。

Triton 使那些對 GPU 硬體和 GPU 特定程式語言(如 CUDA)知之甚少或毫無經驗的開發人員能夠編寫出非常高效的並行程式。

Triton 程式設計模型隱藏了 CUDA 基於執行緒的程式設計模型。因此,Triton 編譯器能夠更好地利用 GPU 硬體,例如,最佳化可能需要顯式同步的情況。

為了實現這種抽象化,開發人員會根據 Triton Python API 進行程式設計,其中的算術計算是在張量而非標量上進行的。這種張量必須具有恆定的形狀、維數和大小(此外,大小必須是 2 的冪次)。

Triton 的釋出公告稱:
Triton 可以讓開發人員以相對較少的工作量達到硬體效能的峰值;例如,它可以用來編寫與 cuBLAS 效能相當的 FP16 矩陣乘法核心--這是許多 GPU 程式設計師無法在 25 行程式碼內做到的。我們的研究人員已經用它編寫出了比同等 Torch 實現效率高達 2 倍的核心,我們很高興能與社群合作,讓每個人都能更方便地使用 GPU 程式設計。

向量加法
為了解釋程式設計模型,我們將介紹一個簡單的例子:向量加法。儘管可以用 CUDA 輕鬆編寫,但這個示例仍具有啟發性。

Triton 網站以教程的形式介紹了完整的示例,包括 Triton 如何與 PyTorch 整合。我們將重點討論 Triton 程式。

import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,
               # 每個程式應處理的元素數量。
                 # 注:<font>"constexpr "可用作形狀值。
               ):
    有多個
"程式 "在處理不同的資料。我們在這裡識別哪個程式
    #:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    該程式將處理從初始資料偏移的輸入
    # 例如,如果有一個長度為 256、塊大小為 64 的向量,程式
    # 將分別訪問元素 [0:64、64:128、128:192、192:256]。
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # 建立掩碼,防止記憶體操作越界訪問。
    mask = offsets < n_elements
    DRAM 中載入 x 和 y,如果輸入不是資料塊大小的
    倍數,則遮蔽掉多餘元素。
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)

程式碼講解點選標題

相關文章