【TVM 教程】使用元組輸入(Tuple Inputs)進行計算和歸約

超神经HyperAI發表於2024-11-19

Apache TVM 是一個端到端的深度學習編譯框架,適用於 CPU、GPU 和各種機器學習加速晶片。更多 TVM 中文文件可訪問 → https://tvm.hyper.ai/

作者:Ziheng Jiang

有時我們希望在單個迴圈中計算多個具有相同 shape 的輸出,或執行涉及多個值的歸約(例如 argmax)。這些問題可以透過元組輸入來解決。

本教程介紹了 TVM 中元組輸入的用法。

from __future__ import absolute_import, print_function

import tvm
from tvm import te
import numpy as np

描述批次計算

對於 shape 相同的運算子(operator),若希望在後續的排程過程中將它們一起排程,可以將它們放在一起作為 te.compute 的輸入。

# Symbolic sizes: the (m, n) shape is not fixed when the computation is defined.
n = te.var("n")
m = te.var("m")
# Two inputs that share the same (m, n) shape.
A0 = te.placeholder((m, n), name="A0")
A1 = te.placeholder((m, n), name="A1")
# A single te.compute whose lambda returns a tuple produces ONE operation with
# two output tensors (B0 = A0 + 2, B1 = A1 * 3), written in the same loop nest.
B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name="B")

# Print the generated IR. B0.op is the operation that also produces B1, so
# scheduling it covers both outputs; both tensors are passed to lower().
s = te.create_schedule(B0.op)
print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))

輸出結果:

@main = primfn(A0_1: handle, A1_1: handle, B_2: handle, B_3: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A0: Buffer(A0_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             A1: Buffer(A1_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
             B: Buffer(B_4: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto"),
             B_1: Buffer(B_5: Pointer(float32), float32, [(stride_3: int32*m)], [], type="auto")}
  buffer_map = {A0_1: A0, A1_1: A1, B_2: B, B_3: B_1}
  preflattened_buffer_map = {A0_1: A0_3: Buffer(A0_2, float32, [m, n: int32], [stride, stride_4: int32], type="auto"), A1_1: A1_3: Buffer(A1_2, float32, [m, n], [stride_1, stride_5: int32], type="auto"), B_2: B_6: Buffer(B_4, float32, [m, n], [stride_2, stride_6: int32], type="auto"), B_3: B_7: Buffer(B_5, float32, [m, n], [stride_3, stride_7: int32], type="auto")} {
  for (i: int32, 0, m) {
    for (j: int32, 0, n) {
      B[((i*stride_2) + (j*stride_6))] = (A0[((i*stride) + (j*stride_4))] + 2f32)
      B_1[((i*stride_3) + (j*stride_7))] = (A1[((i*stride_1) + (j*stride_5))]*3f32)
    }
  }
}

使用協同輸入(Collaborative Inputs)描述歸約

有時需要多個輸入來表達某些歸約運算子(例如 argmax),並且這些輸入會協同工作。在歸約過程中,argmax 不僅要比較運算元的值,還需要保留運算元的索引,這可以用 te.comm_reducer() 表示:

def fcombine(x, y):
    """Combine two partial results of the argmax reduction.

    ``x`` and ``y`` are the reduction operands; each is an ``(index, value)``
    tuple of TIR expressions. The operand with the larger value wins. On a
    tie (``x[1] == y[1]``), ``x`` is kept because the comparison is ``>=``.

    NOTE(review): since ties favour the left operand, the index selected on
    ties depends on the order in which partial results are combined.

    Returns the combined ``(index, value)`` tuple.
    """
    # One comparison drives both selects so index and value stay paired.
    keep_x = x[1] >= y[1]
    best_index = tvm.tir.Select(keep_x, x[0], y[0])
    best_value = tvm.tir.Select(keep_x, x[1], y[1])
    return best_index, best_value

def fidentity(t0, t1):
    """Return the identity element of the argmax reducer.

    Because the reduction operates on tuples, its identity element is a tuple
    too, so this receives one dtype per element: ``t0`` for the index and
    ``t1`` for the value. The index starts at -1 and the value at the minimum
    value of ``t1`` (e.g. -2147483648 for int32).
    """
    initial_index = tvm.tir.const(-1, t0)
    initial_value = tvm.te.min_value(t1)
    return initial_index, initial_value

# Build a commutative reducer from the tuple combiner `fcombine` and the
# tuple identity `fidentity`.
argmax = te.comm_reducer(fcombine, fidentity, name="argmax")

# Describe the reduction: for each row i, reduce the (idx, val) pairs along the
# reduction axis k. T0 receives the idx entry paired with the row's maximum val,
# and T1 receives that maximum value.
m = te.var("m")
n = te.var("n")
# Per-element indices and values, both int32 with shape (m, n).
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="int32")
k = te.reduce_axis((0, n), "k")
T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")

# Print the generated IR; T0.op also produces T1, so both are scheduled together.
s = te.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))

輸出結果:

@main = primfn(idx_1: handle, val_1: handle, T_2: handle, T_3: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {idx: Buffer(idx_2: Pointer(int32), int32, [(stride: int32*m: int32)], [], type="auto"),
             val: Buffer(val_2: Pointer(int32), int32, [(stride_1: int32*m)], [], type="auto"),
             T: Buffer(T_4: Pointer(int32), int32, [(stride_2: int32*m)], [], type="auto"),
             T_1: Buffer(T_5: Pointer(int32), int32, [(stride_3: int32*m)], [], type="auto")}
  buffer_map = {idx_1: idx, val_1: val, T_2: T, T_3: T_1}
  preflattened_buffer_map = {idx_1: idx_3: Buffer(idx_2, int32, [m, n: int32], [stride, stride_4: int32], type="auto"), val_1: val_3: Buffer(val_2, int32, [m, n], [stride_1, stride_5: int32], type="auto"), T_2: T_6: Buffer(T_4, int32, [m], [stride_2], type="auto"), T_3: T_7: Buffer(T_5, int32, [m], [stride_3], type="auto")} {
  for (i: int32, 0, m) {
    T[(i*stride_2)] = -1
    T_1[(i*stride_3)] = -2147483648
    for (k: int32, 0, n) {
      T[(i*stride_2)] = @tir.if_then_else((val[((i*stride_1) + (k*stride_5))] <= T_1[(i*stride_3)]), T[(i*stride_2)], idx[((i*stride) + (k*stride_4))], dtype=int32)
      T_1[(i*stride_3)] = @tir.if_then_else((val[((i*stride_1) + (k*stride_5))] <= T_1[(i*stride_3)]), T_1[(i*stride_3)], val[((i*stride_1) + (k*stride_5))], dtype=int32)
    }
  }
}

使用元組輸入排程操作

雖然一次 batch 操作會產生多個輸出,但它們只能以操作(operation)為單位一起排程。例如,下面對 s[B0] 呼叫 compute_at 時,B1 也會一起被移到 C 的迴圈中計算。

n = te.var("n")
m = te.var("m")
A0 = te.placeholder((m, n), name="A0")
# One operation "B" with two outputs: B0 = A0 + 2 and B1 = A0 * 3.
B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name="B")
A1 = te.placeholder((m, n), name="A1")
# C consumes only B0.
C = te.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name="C")

s = te.create_schedule(C.op)
# Move the B computation inside C's first loop axis. Scheduling is per
# operation, not per tensor: s[B0] is the stage of the whole B operation, so B1
# is computed there as well even though neither C nor lower() uses it.
s[B0].compute_at(s[C], C.op.axis[0])
# The generated IR code:
print(tvm.lower(s, [A0, A1, C], simple_mode=True))

輸出結果:

@main = primfn(A0_1: handle, A1_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A0: Buffer(A0_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
             A1: Buffer(A1_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
             C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
  buffer_map = {A0_1: A0, A1_1: A1, C_1: C}
  preflattened_buffer_map = {A0_1: A0_3: Buffer(A0_2, float32, [m, n: int32], [stride, stride_3: int32], type="auto"), A1_1: A1_3: Buffer(A1_2, float32, [m, n], [stride_1, stride_4: int32], type="auto"), C_1: C_3: Buffer(C_2, float32, [m, n], [stride_2, stride_5: int32], type="auto")} {
  allocate(B.v0: Pointer(global float32), float32, [n]), storage_scope = global;
  allocate(B.v1: Pointer(global float32), float32, [n]), storage_scope = global;
  for (i: int32, 0, m) {
    for (j: int32, 0, n) {
      B.v0_1: Buffer(B.v0, float32, [n], [])[j] = (A0[((i*stride) + (j*stride_3))] + 2f32)
      B.v1_1: Buffer(B.v1, float32, [n], [])[j] = (A0[((i*stride) + (j*stride_3))]*3f32)
    }
    for (j_1: int32, 0, n) {
      C[((i*stride_2) + (j_1*stride_5))] = (A1[((i*stride_1) + (j_1*stride_4))] + B.v0_1[j_1])
    }
  }
}

總結

本教程介紹元組輸入操作的用法。

  • 描述常規的批次計算。
  • 用元組輸入描述歸約操作。
  • 注意,只能根據操作而不是張量來排程計算。

下載 Python 原始碼:tuple_inputs.py

下載 Jupyter Notebook:tuple_inputs.ipynb

相關文章