ONNXRuntime學習筆記(三)

zq發表於2022-05-01

接上一篇完成的pytorch模型訓練結果,模型結構為ResNet18+fc,引數量約為11M,最終測試集Acc達到94.83%。接下來有分兩個部分:匯出onnx和使用onnxruntime推理。

一、pytorch匯出onnx

直接放函式吧,這部分我是直接放在test.py裡面的,直接從dataloader中拿到一個batch的資料走一遍推理即可。

def export_onnx(net, testloader, output_file):
    net.eval()
    with torch.no_grad():
        for data in testloader: 
            images, labels = data

            torch.onnx.export(net, 
                            (images), 
                            output_file,
                            training=False,
                            do_constant_folding=True,
                            input_names=["img"], 
                            output_names=["output"],
                  dynamic_axes={"img": {0: "b"},"output": {0: "b"}}
                  )
            print("onnx export done!")
            break

上面函式中幾個比較重要的引數:do_constant_folding是常量摺疊,建議開啟;輸入張量通過一個tuple傳入,並且最好指定每個輸入和輸出的名稱,此外,為保證使用onnxruntime推理的時候batchsize可變,dynamic_axes的第一維需要像上述一樣設定為動態的。如果是全卷積做分割的網路,類似的輸入h和w也應該是動態的。

單獨執行test.py計算測試集效果和平均相應時間,為方便比較,這裡batch_size設定為1,結果為:

Test Acc is: 94.84%
Average response time cost: 8.703978610038757 ms

二、使用onnxruntime推理

這裡我們使用gpu版本的onnxruntime庫進行推理,其python包可直接pip install onnxruntime-gpu安裝。onnxruntime推理程式碼和測試集推理程式碼很類似,如下:

import numpy as np
import onnxruntime as ort
import argparse, os
from lib import CIFARDataset

def onnxruntime_test(session, testloader):
    print("Start Testing!")
    input_name = session.get_inputs()[0].name
    correct = 0
    total = 0   # 計數歸零(初始化)
    for data in testloader:
        images, labels = data
        images, labels = images.numpy(), labels.numpy()
        outputs = session.run(None, {input_name:images})
        predicted = np.argmax(outputs[0], axis=1)  # 取得分最高的那個類
        total += labels.shape[0]                        # 累加樣本總數
        correct += (predicted == labels).sum()        # 累加預測正確的樣本個數
    acc = correct / total
    print('ONNXRuntime Test Acc is: %.2f%%' % (100*acc))
            
if __name__ == '__main__':
    # 命令列引數解析
    parser = argparse.ArgumentParser("CNN backbone on cifar10")
    parser.add_argument('--onnx', default='./output/test_resnet18_10_autoaug/densenet_best.onnx')
    args = parser.parse_args()

    NUM_CLASS =10
    BATCH_SIZE = 1 # 批處理尺寸(batch_size)

    # 資料集迭代器
    data_path="./data"
    dataset = CIFARDataset(dataset_path=data_path, batchsize=BATCH_SIZE)
    _, testloader = dataset.get_cifar10_dataloader()

    # 構建session
    sess = ort.InferenceSession(args.onnx, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    #onnxruntime推理
    import time
    start = time.time()
    onnxruntime_test(sess, testloader)
    end = time.time()
    print(f"Average response time cost: {1000*(end-start)/len(testloader.dataset)} ms")

使用onnxruntime載入匯出的onnx模型,計算測試集效果和平均響應時間,結果為:

ONNXRuntime Test Acc is: 94.83%
Average response time cost: 3.1050602436065673 ms

三、小結

分析上面的pytorch和onnxruntime的測試結果可知,最終測試集效果是一致的,Acc分別為94.84%和94.83%,相當於10000個樣本里面只有1個的預測結果不一致,這是可以接受範圍內。但onnxruntime的效率更高,平均耗時只有3.1ms,比pytorch的8.7ms快了將近3倍。這在實際部署中的優勢是非常明顯的。目前Python端的結論比最初目標設定的50ms高很多,如果說需要進一步優化,兩個方向:模型量化或並行化推理(拼batch或多執行緒)。下一篇再分析。

相關文章