Apache TVM 是一個端到端的深度學習編譯框架,適用於 CPU、GPU 和各種機器學習加速晶片。更多 TVM 中文文件可訪問 → https://tvm.hyper.ai/
作者:Ziheng Jiang
若要在單個迴圈中計算具有相同 shape 的多個輸出,或執行多個值的歸約,例如 argmax。這些問題可以透過元組輸入來解決。
本教程介紹了 TVM 中元組輸入的用法。
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
描述批次計算
對於 shape 相同的運算元,若要在下一個排程過程中一起排程,可以將它們放在一起作為 te.compute 的輸入。
n = te.var("n")
m = te.var("m")
A0 = te.placeholder((m, n), name="A0")
A1 = te.placeholder((m, n), name="A1")
B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name="B")
# 生成的 IR 程式碼:
s = te.create_schedule(B0.op)
print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
輸出結果:
@main = primfn(A0_1: handle, A1_1: handle, B_2: handle, B_3: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {A0: Buffer(A0_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
A1: Buffer(A1_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
B: Buffer(B_4: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto"),
B_1: Buffer(B_5: Pointer(float32), float32, [(stride_3: int32*m)], [], type="auto")}
buffer_map = {A0_1: A0, A1_1: A1, B_2: B, B_3: B_1}
preflattened_buffer_map = {A0_1: A0_3: Buffer(A0_2, float32, [m, n: int32], [stride, stride_4: int32], type="auto"), A1_1: A1_3: Buffer(A1_2, float32, [m, n], [stride_1, stride_5: int32], type="auto"), B_2: B_6: Buffer(B_4, float32, [m, n], [stride_2, stride_6: int32], type="auto"), B_3: B_7: Buffer(B_5, float32, [m, n], [stride_3, stride_7: int32], type="auto")} {
for (i: int32, 0, m) {
for (j: int32, 0, n) {
B[((i*stride_2) + (j*stride_6))] = (A0[((i*stride) + (j*stride_4))] + 2f32)
B_1[((i*stride_3) + (j*stride_7))] = (A1[((i*stride_1) + (j*stride_5))]*3f32)
}
}
}
使用協同輸入(Collaborative Inputs)描述歸約
有時需要多個輸入來表達歸約運算元,並且輸入會協同工作,例如 argmax。在歸約過程中,argmax 要比較運算元的值,還需要保留運算元的索引,可用 te.comm_reducer() 表示:
# x 和 y 是歸約的運算元,它們都是元組的索引和值。
def fcombine(x, y):
lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
return lhs, rhs
# 身份元素也要是一個元組,所以 `fidentity` 接收兩種型別作為輸入。
def fidentity(t0, t1):
return tvm.tir.const(-1, t0), tvm.te.min_value(t1)
argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
# 描述歸約計算
m = te.var("m")
n = te.var("n")
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="int32")
k = te.reduce_axis((0, n), "k")
T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
# 生成的 IR 程式碼:
s = te.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
輸出結果:
@main = primfn(idx_1: handle, val_1: handle, T_2: handle, T_3: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {idx: Buffer(idx_2: Pointer(int32), int32, [(stride: int32*m: int32)], [], type="auto"),
val: Buffer(val_2: Pointer(int32), int32, [(stride_1: int32*m)], [], type="auto"),
T: Buffer(T_4: Pointer(int32), int32, [(stride_2: int32*m)], [], type="auto"),
T_1: Buffer(T_5: Pointer(int32), int32, [(stride_3: int32*m)], [], type="auto")}
buffer_map = {idx_1: idx, val_1: val, T_2: T, T_3: T_1}
preflattened_buffer_map = {idx_1: idx_3: Buffer(idx_2, int32, [m, n: int32], [stride, stride_4: int32], type="auto"), val_1: val_3: Buffer(val_2, int32, [m, n], [stride_1, stride_5: int32], type="auto"), T_2: T_6: Buffer(T_4, int32, [m], [stride_2], type="auto"), T_3: T_7: Buffer(T_5, int32, [m], [stride_3], type="auto")} {
for (i: int32, 0, m) {
T[(i*stride_2)] = -1
T_1[(i*stride_3)] = -2147483648
for (k: int32, 0, n) {
T[(i*stride_2)] = @tir.if_then_else((val[((i*stride_1) + (k*stride_5))] <= T_1[(i*stride_3)]), T[(i*stride_2)], idx[((i*stride) + (k*stride_4))], dtype=int32)
T_1[(i*stride_3)] = @tir.if_then_else((val[((i*stride_1) + (k*stride_5))] <= T_1[(i*stride_3)]), T_1[(i*stride_3)], val[((i*stride_1) + (k*stride_5))], dtype=int32)
}
}
}
使用元組輸入排程操作
雖然一次 batch 操作會有多個輸出,但它們只能一起排程。
n = te.var("n")
m = te.var("m")
A0 = te.placeholder((m, n), name="A0")
B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name="B")
A1 = te.placeholder((m, n), name="A1")
C = te.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name="C")
s = te.create_schedule(C.op)
s[B0].compute_at(s[C], C.op.axis[0])
# 生成的 IR 程式碼:
print(tvm.lower(s, [A0, A1, C], simple_mode=True))
輸出結果:
@main = primfn(A0_1: handle, A1_1: handle, C_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {A0: Buffer(A0_2: Pointer(float32), float32, [(stride: int32*m: int32)], [], type="auto"),
A1: Buffer(A1_2: Pointer(float32), float32, [(stride_1: int32*m)], [], type="auto"),
C: Buffer(C_2: Pointer(float32), float32, [(stride_2: int32*m)], [], type="auto")}
buffer_map = {A0_1: A0, A1_1: A1, C_1: C}
preflattened_buffer_map = {A0_1: A0_3: Buffer(A0_2, float32, [m, n: int32], [stride, stride_3: int32], type="auto"), A1_1: A1_3: Buffer(A1_2, float32, [m, n], [stride_1, stride_4: int32], type="auto"), C_1: C_3: Buffer(C_2, float32, [m, n], [stride_2, stride_5: int32], type="auto")} {
allocate(B.v0: Pointer(global float32), float32, [n]), storage_scope = global;
allocate(B.v1: Pointer(global float32), float32, [n]), storage_scope = global;
for (i: int32, 0, m) {
for (j: int32, 0, n) {
B.v0_1: Buffer(B.v0, float32, [n], [])[j] = (A0[((i*stride) + (j*stride_3))] + 2f32)
B.v1_1: Buffer(B.v1, float32, [n], [])[j] = (A0[((i*stride) + (j*stride_3))]*3f32)
}
for (j_1: int32, 0, n) {
C[((i*stride_2) + (j_1*stride_5))] = (A1[((i*stride_1) + (j_1*stride_4))] + B.v0_1[j_1])
}
}
}
總結
本教程介紹元組輸入操作的用法。
- 描述常規的批次計算。
- 用元組輸入描述歸約操作。
- 注意,只能根據操作而不是張量來排程計算。
下載 Python 原始碼:tuple_inputs.py
下載 Jupyter Notebook:tuple_inputs.ipynb