pytorch resnet50 模型轉成 onnx 的程式碼樣例,batch_size 是動態可調節的

ponponon發表於2023-04-03

pytorch resnet50 模型轉成 onnx 的程式碼樣例,batch_size 是動態可調節的

以下是一個將 PyTorch 的 ResNet50 模型轉換為 ONNX 格式的示例程式碼,並使用可變批處理大小:

import torch
import torchvision
import onnx
import onnxruntime

# 建立 PyTorch ResNet50 模型例項
model = torchvision.models.resnet50(pretrained=True)

# 定義一個 PyTorch 張量來模擬輸入資料
batch_size = 4  # 定義批處理大小
input_shape = (batch_size, 3, 224, 224)
input_data = torch.randn(input_shape)

# 將模型轉換為 ONNX 格式
output_path = "resnet50.onnx"
torch.onnx.export(model, input_data, output_path,
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})

# 使用 ONNX 執行時載入模型
session = onnxruntime.InferenceSession(output_path)

# 定義一個 ONNX 張量來模擬輸入資料
new_batch_size = 8  # 定義新的批處理大小
new_input_shape = (new_batch_size, 3, 224, 224)
new_input_data = torch.randn(new_input_shape)

# 在 ONNX 執行時中執行模型
outputs = session.run(["output"], {"input": new_input_data.numpy()})

注意,在將模型匯出為 ONNX 格式時,需要指定 input_names 和 output_names 引數來指定輸入和輸出張量的名稱,以便在 ONNX 執行時中使用。此外,我們還需要使用 dynamic_axes 引數來指定批處理大小的動態維度。最後,在 ONNX 執行時中使用 session.run() 方法來執行模型。

相關文章