本篇回答來源於 TVM 官方英文檔案 Lianmin Zheng,Chengfan Jia。更多 TVM 中文檔案可訪問→https://tvm.hyper.ai/
本教程將展示 TVM 的 Auto Scheduling 功能,如何在不編寫自定義模板的情況下,找到最佳 schedule。
與基於模板的 AutoTVM 依賴手工模板來定義搜尋空間不同,auto-scheduler 不需要任何模板。使用者只需編寫計算宣告,無需任何 schedule 命令或模板。auto-scheduler 可以自動生成一個大的搜尋空間,並在空間中找到最優 schedule。
本教程中使用矩陣乘法作為示例。
注意,本教程不會在 Windows 或最新版本的 macOS 上執行。如需執行,請將本教程的主體放在 if name == "__main__": 程式碼塊中。
import os
import numpy as np
import tvm
from tvm import te, auto_scheduler
定義矩陣乘法
首先,定義一個帶有偏置加法的矩陣乘法。注意,這裡使用了 TVM 張量表示式語言中的標準操作。主要區別在於函式定義上方使用了 register_workload 裝飾器。該函式應返回輸入/輸出張量列表。透過這些張量,auto-scheduler 可以得到整個計算圖。
@auto_scheduler.register_workload # Note the auto_scheduler decorator
def matmul_add(N, L, M, dtype):
A = te.placeholder((N, L), name="A", dtype=dtype)
B = te.placeholder((L, M), name="B", dtype=dtype)
C = te.placeholder((N, M), name="C", dtype=dtype)
k = te.reduce_axis((0, L), name="k")
matmul = te.compute(
(N, M),
lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
name="matmul",
attrs={"layout_free_placeholders": [B]}, # enable automatic layout transform for tensor B
)
out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")
return [A, B, C, out]
建立搜尋任務
定義函式後,可以為 auto_scheduler 建立要搜尋的任務。我們為這個矩陣乘法指定了特定的引數,如這裡是兩個大小為 1024x1024 的矩陣乘法。然後我們建立一個 N=L=M=1024 和 dtype="float32" 的搜尋任務
使用自定義 TARGET 提高效能
為讓 TVM 充分利用特定的硬體平臺,需要手動指定 CPU 功能。例如:
啟用 AVX2:將下面的 llvm 替換為 llvm -mcpu=core-avx2
啟用 AVX-512:將下面的 llvm 替換為 llvm -mcpu=skylake-avx512
target = tvm.target.Target("llvm")
N = L = M = 1024
task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)
# 檢查計算圖
print("Computational DAG:")
print(task.compute_dag)
輸出結果:
Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])
設定 auto-scheduler 的引數
接下來,為 auto-scheduler 設定引數。
num_measure_trials 表示搜尋過程中可用的測試試驗次數。本教程為了快速演示只進行了 10 次試驗。實際上,1000 是搜尋收斂的最佳數量。可以根據自己的時間預算進行更多試驗。
此外,我們用 RecordToFile 將測試記錄記錄到檔案 matmul.json 中。測試記錄可用於查詢歷史最佳、恢復搜尋以及以後進行更多分析。
有關更多引數,參見 TuningOptions
log_file = "matmul.json"
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=10,
measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
verbose=2,
)
開始搜尋
準備好所有輸入就可以開始搜尋,讓 auto-scheduler 發揮它的作用。經過一些測試試驗後,可從日誌檔案中載入最佳 schedule 並應用。
# 執行 auto-tuning(搜尋)
task.tune(tune_option)
# 應用最佳 schedule
sch, args = task.apply_best(log_file)
檢查最佳化的 schedule
auto-scheduling 完成後,可將 schedule 降級來檢視 IR。auto-scheduler 執行合適的最佳化,包括多級迴圈切分、佈局轉換、並行化、向量化、迴圈展開和運算元融合。
print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))
輸出結果:
Lowered TIR:
@main = primfn(A_1: handle, B_1: handle, C_1: handle, out_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {A: Buffer(A_2: Pointer(float32), float32, [1048576], []),
B: Buffer(B_2: Pointer(float32), float32, [1048576], []),
C: Buffer(C_2: Pointer(float32), float32, [1048576], []),
out: Buffer(out_2: Pointer(float32), float32, [1048576], [])}
buffer_map = {A_1: A, B_1: B, C_1: C, out_1: out}
preflattened_buffer_map = {A_1: A_3: Buffer(A_2, float32, [1024, 1024], []), B_1: B_3: Buffer(B_2, float32, [1024, 1024], []), C_1: C_3: Buffer(C_2, float32, [1024, 1024], []), out_1: out_3: Buffer(out_2, float32, [1024, 1024], [])} {
allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [1048576]), storage_scope = global {
for (ax0.ax1.fused.ax2.fused: int32, 0, 128) "parallel" {
for (ax4: int32, 0, 256) {
for (ax6: int32, 0, 4) {
for (ax7: int32, 0, 8) {
auto_scheduler_layout_transform_1: Buffer(auto_scheduler_layout_transform, float32, [1048576], [])[((((ax0.ax1.fused.ax2.fused*8192) + (ax4*32)) + (ax6*8)) + ax7)] = B[((((ax4*4096) + (ax6*1024)) + (ax0.ax1.fused.ax2.fused*8)) + ax7)]
}
}
}
}
for (i.outer.outer.j.outer.outer.fused: int32, 0, 16384) "parallel" {
allocate(matmul: Pointer(global float32x8), float32x8, [4]), storage_scope = global;
for (i.outer.inner: int32, 0, 2) {
matmul_1: Buffer(matmul, float32x8, [4], [])[0] = broadcast(0f32, 8)
matmul_1[1] = broadcast(0f32, 8)
matmul_1[2] = broadcast(0f32, 8)
matmul_1[3] = broadcast(0f32, 8)
for (k.outer: int32, 0, 256) {
for (k.inner: int32, 0, 4) {
let cse_var_2: int32 = (((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8))
let cse_var_1: int32 = ((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner)
{
matmul_1[0] = (matmul_1[0] + (broadcast(A[cse_var_1], 8)*auto_scheduler_layout_transform_1[ramp(cse_var_2, 1, 8)]))
matmul_1[1] = (matmul_1[1] + (broadcast(A[(cse_var_1 + 1024)], 8)*auto_scheduler_layout_transform_1[ramp(cse_var_2, 1, 8)]))
matmul_1[2] = (matmul_1[2] + (broadcast(A[(cse_var_1 + 2048)], 8)*auto_scheduler_layout_transform_1[ramp(cse_var_2, 1, 8)]))
matmul_1[3] = (matmul_1[3] + (broadcast(A[(cse_var_1 + 3072)], 8)*auto_scheduler_layout_transform_1[ramp(cse_var_2, 1, 8)]))
}
}
}
for (i.inner: int32, 0, 4) {
let cse_var_3: int32 = ((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (i.inner*1024)) + (floormod(i.outer.outer.j.outer.outer.fused, 128)*8))
out[ramp(cse_var_3, 1, 8)] = (matmul_1[i.inner] + C[ramp(cse_var_3, 1, 8)])
}
}
}
}
}
檢查正確性並評估效能
構建二進位制檔案並檢查其正確性和效能。
func = tvm.build(sch, args, target)
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_np
dev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)
# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)
# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
"Execution time of this operator: %.3f ms"
% (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)
輸出結果:
Execution time of this operator: 93.286 ms
使用記錄檔案
在搜尋過程中,所有測試記錄都儲存到記錄檔案 matmul.json 中。測試記錄可用於重新應用搜尋結果、恢復搜尋和執行其他分析。
下面是從檔案中載入最佳 schedule,並列印等效的 Python schedule API 的例子。可用於除錯和學習 auto-scheduler 的行為。
print("Equivalent python schedule:")
print(task.print_best(log_file))
輸出結果:
Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=4)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=1)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=2)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=8)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=1)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=1)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=4)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
out_i_o_i, out_i_i = s[out].split(out_i, factor=4)
out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=2)
out_j_o_i, out_j_i = s[out].split(out_j, factor=8)
out_j_o_o, out_j_o_i = s[out].split(out_j_o_i, factor=1)
s[out].reorder(out_i_o_o, out_j_o_o, out_i_o_i, out_j_o_i, out_i_i, out_j_i)
s[matmul].compute_at(s[out], out_j_o_i)
out_i_o_o_j_o_o_fused = s[out].fuse(out_i_o_o, out_j_o_o)
s[out].parallel(out_i_o_o_j_o_o_fused)
s[matmul].pragma(matmul_i_o_o_o, "auto_unroll_max_step", 8)
s[matmul].pragma(matmul_i_o_o_o, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)
s[out].vectorize(out_j_i)
恢復搜尋則更為複雜,需要自己建立搜尋策略和 cost model,並透過日誌檔案恢復搜尋策略和 cost model 的狀態。下面的示例進行了 5 次試驗來恢復它們的狀態:
def resume_search(task, log_file):
print("Resume search:")
cost_model = auto_scheduler.XGBModel()
cost_model.update_from_file(log_file)
search_policy = auto_scheduler.SketchPolicy(
task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
)
task.tune(tune_option, search_policy=search_policy)
resume_search(task, log_file)
輸出結果:
Resume search:
/usr/local/lib/python3.7/dist-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated. See: https://xgboost.readthedocs.io/en/latest/python/callbacks.html
warnings.warn(f'Old style callback is deprecated. See: {link}', UserWarning)
最後的說明和總結
本教程展示瞭如何在不指定搜尋模板的情況下,使用 TVM Auto-Scheduler 自動最佳化矩陣乘法。從張量表示式(TE)語言開始,演示了一系列關於 TVM 如何最佳化計算操作的示例。
下載 Python 原始碼:auto_scheduler_matmul_x86.py
下載 Jupyter Notebook:auto_scheduler_matmul_x86.ipynb
以上就是該檔案的全部內容,檢視更多 TVM 中文檔案,請訪問→https://tvm.hyper.ai/