【推理引擎】如何在 ONNXRuntime 中新增新的運算元

虔誠的樹發表於2022-03-30

如果模型中有些運算元不被ONNX運算元庫支援,我們就需要利用ONNXRuntime提供的API手動新增新運算元。在官方文件中已經對如何新增定製運算元進行了介紹(https://onnxruntime.ai/docs/reference/operators/add-custom-op.html ),這裡我們主要把原始碼中對應的流程給捋清楚。

新增定製運算元(Custom Operators)主要分為三步:

  1. 建立一個定製運算元域(CusttomOpDomain);
  2. 建立一個定製運算元(CustomOp),並將該運算元新增到定製運算元域中;
  3. 將定製運算元域新增到 SessionOption 中

首先看看原始碼中給出的定製運算元樣例:

// file path: onnxruntime/test/shared_lib/custom_op_utils.h

// 首先定義定製運算元的核
struct MyCustomKernel {
  MyCustomKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/, void* compute_stream)
      : ort_(ort), compute_stream_(compute_stream) {
  }

  void Compute(OrtKernelContext* context);

 private:
  Ort::CustomOpApi ort_;
  void* compute_stream_;
};

// 然後定義定製運算元的各個操作,各個成員函式均已實現,其中 CreateKernel 會返回前面定義的運算元核物件
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomKernel> {
  explicit MyCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {}

  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new MyCustomKernel(api, info, compute_stream_); };
  const char* GetName() const { return "Foo"; };
  const char* GetExecutionProviderType() const { return provider_; };

  size_t GetInputTypeCount() const { return 2; };
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    // Both the inputs need to be necessarily of float type
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  };

  size_t GetOutputTypeCount() const { return 1; };
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };

 private:
  const char* provider_;
  void* compute_stream_;
};

在上面程式碼中,我們看到定製運算元繼承自 Ort::CustomOpBase<MyCustomOp, MyCustomKernel>,這種擴充套件類作為模板基類的模板引數的方式又被稱為CRTP,接著深入到這個模板類內部:

// file path: include/onnxruntime/core/session/onnxruntime_cxx_api.h

template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
  CustomOpBase() {
    OrtCustomOp::version = ORT_API_VERSION;
    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
    OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };

    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };

    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
    OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };

    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
    OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
    OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };

    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
  }

  // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
  const char* GetExecutionProviderType() const { return nullptr; }

  // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
  // (inputs and outputs are required by default)
  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  }

  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  }
};

這裡的 CustomOpBase 又繼承自 OrtCustomOp

// include/onnxruntime/core/session/onnxruntime_c_api.h

struct OrtCustomOp;
typedef struct OrtCustomOp OrtCustomOp;

struct OrtCustomOp {
  uint32_t version;  // Must be initialized to ORT_API_VERSION

  // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below.
  void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api,
                                    _In_ const OrtKernelInfo* info);

  // Returns the name of the op
  const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op);

  // Returns the type of the execution provider, return nullptr to use CPU execution provider
  const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op);

  // Returns the count and types of the input & output tensors
  ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
  size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op);
  ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
  size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op);

  // Op kernel callbacks
  void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
  void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel);

  // Returns the characteristics of the input & output tensors
  OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
  OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index);
};

可以發現,OrtCustomOp 中定義了定製運算元應該實現的模式,其中的一系列回撥函式由其派生類一一實現,比如上文提到的 CustomOpBase 在其建構函式中,以 lambda 函式的方式實現各個回撥函式。

至此,我們已經完整地梳理了定義定製運算元在原始碼內部是如何實現的,接下來介紹如何將定義好的定製運算元使用起來。

從如下官方測試程式碼開始分析:

// file path: onnxruntime/test/shared_lib/test_inference.cc

TEST(CApiTest, custom_op_handler) {
  std::cout << "Running custom op inference" << std::endl;

  std::vector<Input> inputs(1);
  Input& input = inputs[0];
  input.name = "X";
  input.dims = {3, 2};
  input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  // prepare expected inputs and outputs
  std::vector<int64_t> expected_dims_y = {3, 2};
  std::vector<float> expected_values_y = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f};

  // 建立定製運算元(MyCustomOp)
#ifdef USE_CUDA
  cudaStream_t compute_stream = nullptr;    // 宣告一個 cuda stream
  cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking);  // 建立一個 cuda stream
  MyCustomOp custom_op{onnxruntime::kCudaExecutionProvider, compute_stream};
#else
  MyCustomOp custom_op{onnxruntime::kCpuExecutionProvider, nullptr};
#endif
  
  // 建立定製運算元域(CustomOpDomain)
  Ort::CustomOpDomain custom_op_domain("");
  // 在定製運算元域中新增定製運算元
  custom_op_domain.Add(&custom_op);

  // 進入 TestInference
#ifdef USE_CUDA
  TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 1,
                       custom_op_domain, nullptr, nullptr, false, compute_stream);
  cudaStreamDestroy(compute_stream);
#else
  TestInference<float>(*ort_env, CUSTOM_OP_MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, 0,
                       custom_op_domain, nullptr);
#endif
}

以上程式碼需要特別注意的是,需要根據巨集(USE_CUDA)用來判斷是否使用CUDA。如果使用 CUDA:

  • 當模型執行在GPU上,而插入的是 CPU 定製運算元,那麼 ONNXRuntime 會在 CPU 定製運算元前後分別插入兩個操作 MemcpyToHost、MemcpyFromHost,這兩個操作負責記憶體拷貝,即首先從 Device 拷貝到 Host,再從 Host 拷貝到 Device;
  • 如果插入的是 GPU 定製運算元,為了確保 ORT 的 CUDA kernels 和定製 CUDA kernels 之間的同步,它們必須使用同一個 CUDA 計算流。具體細節在下一個程式碼繼續分析。

這裡建立 cuda stream 的方式是 cudaStreamCreateWithFlags,該函式和 cudaStreamCreate 不同,後者在多次呼叫時是序列方式執行,而前者可同步執行。如果將引數 cudaStreamNonBlocking 替換為 cudaStreamDefault,則 cudaStreamCreateWithFlags 的行為將和 cudaStreamCreate 相同。【參考內容:CUDA 5.0 中cudaStreamCreateWithFlags 的用法

無論是否使用CDUA,我們都需要建立定製運算元(MyCustomOp)。

進入 TestInference 函式內部:

// file path: onnxruntime/test/shared_lib/test_inference.cc

template <typename OutT>
static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& model_uri,
                          const std::vector<Input>& inputs,
                          const char* output_name,
                          const std::vector<int64_t>& expected_dims_y,
                          const std::vector<OutT>& expected_values_y,
                          int provider_type,
                          OrtCustomOpDomain* custom_op_domain_ptr,
                          const char* custom_op_library_filename,
                          void** library_handle = nullptr,
                          bool test_session_creation_only = false,
                          void* cuda_compute_stream = nullptr) {
  Ort::SessionOptions session_options;

  if (provider_type == 1) {
#ifdef USE_CUDA
    std::cout << "Running simple inference with cuda provider" << std::endl;
    auto cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream);
    session_options.AppendExecutionProvider_CUDA(cuda_options);
#else
    ORT_UNUSED_PARAMETER(cuda_compute_stream);
    return;
#endif
  } else if (provider_type == 2) {
#ifdef USE_DNNL
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Dnnl(session_options, 1));
    std::cout << "Running simple inference with dnnl provider" << std::endl;
#else
    return;
#endif
  } else if (provider_type == 3) {
#ifdef USE_NUPHAR
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nuphar(session_options,
                                                                      /*allow_unaligned_buffers*/ 1, ""));
    std::cout << "Running simple inference with nuphar provider" << std::endl;
#else
    return;
#endif
  } else {
    std::cout << "Running simple inference with default provider" << std::endl;
  }
  if (custom_op_domain_ptr) {
    session_options.Add(custom_op_domain_ptr);
  }

  if (custom_op_library_filename) {
    Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary(session_options,
                                                             custom_op_library_filename, library_handle));
  }

  // if session creation passes, model loads fine
  Ort::Session session(env, model_uri.c_str(), session_options);

  // caller wants to test running the model (not just loading the model)
  if (!test_session_creation_only) {
    // Now run
    auto default_allocator = std::make_unique<MockedOrtAllocator>();

    //without preallocated output tensor
    RunSession<OutT>(default_allocator.get(),
                     session,
                     inputs,
                     output_name,
                     expected_dims_y,
                     expected_values_y,
                     nullptr);
    //with preallocated output tensor
    Ort::Value value_y = Ort::Value::CreateTensor<float>(default_allocator.get(),
                                                         expected_dims_y.data(), expected_dims_y.size());

    //test it twice
    for (int i = 0; i != 2; ++i)
      RunSession<OutT>(default_allocator.get(),
                       session,
                       inputs,
                       output_name,
                       expected_dims_y,
                       expected_values_y,
                       &value_y);
  }
}

前文提到,如果對應EP是CUDA,需要確保 ORT 的 CUDA kernels 和定製 CUDA kernels 之間的同步。為了實現這一目標,首先通過 CreateDefaultOrtCudaProviderOptionsWithCustomStream 函式將新建立的 CUDA 計算流以 OrtCudaProviderOptions 的形式傳遞給 SessionOptions:

OrtCUDAProviderOptions cuda_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(cuda_compute_stream);
session_options.AppendExecutionProvider_CUDA(cuda_options)

之後,將定製運算元域也新增到 SessionOptions 中:

if (custom_op_domain_ptr) {
  session_options.Add(custom_op_domain_ptr);
}

至此,SessionOptions 已經構建完成,下面建立 Session 並通過 model_uri 載入模型:

Ort::Session session(env, model_uri.c_str(), session_options);

這裡的(1)Ort::Session 是在 onnxruntime_cxx_api.h 檔案中宣告的類,(2)對應的建構函式在 onnxruntime_cxx_inline.h 中實現,(3)實現方式是進一步呼叫 onnxruntime_c_api.h 中定義的 API,該 API 也僅僅是宣告,(4)最終對應的實現在 onnxruntime_c_api.cc 檔案中:

// (1) include/onnxruntime/core/session/onnxruntime_cxx_api.h
struct Session : Base<OrtSession> {
  explicit Session(std::nullptr_t) {}
  Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
}

// (2) include/onnxruntime/core/session/onnxruntime_cxx_inline.h
inline Session::Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
  ThrowOnError(GetApi().CreateSession(env, model_path, options, &p_));
}

// (3) include/onnxruntime/core/session/onnxruntime_c_api.h
ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
                _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out);

// (4) onnxruntime/core/session/onnxruntime_c_api.cc
ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
                    _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) {
  API_IMPL_BEGIN
  std::unique_ptr<onnxruntime::InferenceSession> sess;
  OrtStatus* status = nullptr;
  *out = nullptr;

  ORT_TRY {
    ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess));
    ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess));

    *out = reinterpret_cast<OrtSession*>(sess.release());
  }
  ORT_CATCH(const std::exception& e) {
    ORT_HANDLE_EXCEPTION([&]() {
      status = OrtApis::CreateStatus(ORT_FAIL, e.what());
    });
  }

  return status;
  API_IMPL_END
}

可以發現,Ort::Session 內部還是呼叫了 onnxruntime::InferenceSession

扯遠了,下面迴歸主題。

建立 Session 完成之後,便開始執行,進入 RunSession 函式內部:

// file path: onnxruntime/test/shared_lib/test_inference.cc

template <typename OutT>
void RunSession(OrtAllocator* allocator, Ort::Session& session_object,
                const std::vector<Input>& inputs,
                const char* output_name,
                const std::vector<int64_t>& dims_y,
                const std::vector<OutT>& values_y,
                Ort::Value* output_tensor) {
  
  // 構建模型輸入
  std::vector<Ort::Value> ort_inputs;
  std::vector<const char*> input_names;
  for (size_t i = 0; i < inputs.size(); i++) {
    input_names.emplace_back(inputs[i].name);
    ort_inputs.emplace_back(
        Ort::Value::CreateTensor<float>(allocator->Info(allocator), const_cast<float*>(inputs[i].values.data()),
                                        inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size()));
  }
  
  // 執行 RUN
  std::vector<Ort::Value> ort_outputs;
  if (output_tensor)
    session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
                       &output_name, output_tensor, 1);
  else {
    ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
                                     &output_name, 1);
    ASSERT_EQ(ort_outputs.size(), 1u);
    output_tensor = &ort_outputs[0];
  }

  auto type_info = output_tensor->GetTensorTypeAndShapeInfo();
  ASSERT_EQ(type_info.GetShape(), dims_y);
  size_t total_len = type_info.GetElementCount();
  ASSERT_EQ(values_y.size(), total_len);

  OutT* f = output_tensor->GetTensorMutableData<OutT>();
  for (size_t i = 0; i != total_len; ++i) {
    ASSERT_EQ(values_y[i], f[i]);
  }
}

這裡使用了一些GTest中的斷言來判定執行結果是否符合預期。

至此,我們已經完整地分析了定製運算元從定義到使用的全部流程。

文件中還提到了 Contrib ops,這類運算元歸屬於 contrib ops domain,是嵌入到 runtime 內部的,對於一些使用低頻的運算元最好不要加入這個域中,否則會導致執行時庫(runtime library)過大。
官方文件中給出了新增運算元到這個域中的方法,這裡就不再進行介紹了,以後用到了再說吧。

相關文章