ONNX Runtime 原始碼閱讀:Graph::SetGraphInputsOutputs() 函式

虔誠的樹發表於2022-05-04

前言

為了深入理解ONNX Runtime的底層機制,本文將對 Graph::SetGraphInputsOutputs() 的程式碼逐行分析。

正文

首先判斷Graph是否從ONNX檔案中載入所得:

if (is_loaded_from_model_file_) return Status::OK();

如果是,可直接返回;如果不是,則需要解析Graph中的節點,從而設定模型的輸入和輸出。

Graph中的成員變數 value_info_graph_inputs_excluding_initializers_graph_inputs_including_initializers_ 以及 graph_outputs_ 全部清空:

value_info_.clear();

graph_inputs_excluding_initializers_.clear();

if (!graph_inputs_manually_set_) {
  graph_inputs_including_initializers_.clear();
} else {
  std::unordered_set<std::string> existing_names;
  for (auto arg : graph_inputs_including_initializers_) {
    const std::string& name = arg->Name();
    if (existing_names.count(name) == 0) {
      graph_inputs_excluding_initializers_.push_back(arg);
      existing_names.insert(name);
    }
  }
}

if (!graph_outputs_manually_set_) {
  graph_outputs_.clear();
}

設定一些區域性變數,方便下面的使用分析:

std::unordered_map<std::string, size_t> output_name_to_node_arg_index;
std::vector<const NodeArg*> output_node_args_in_order;
std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};

統計所有節點的輸出,新增到以上區域性變數(output_name_to_node_arg_index 和 output_node_args_in_order)中:

for (const auto& node : Nodes()) {
  for (const auto* output_def : node.OutputDefs()) {
    if (output_def->Exists()) {
      output_node_args_in_order.push_back(output_def);
      output_name_to_node_arg_index.insert({output_def->Name(), output_node_args_in_order.size() - 1});
    }
  }
}
auto graph_output_args = output_name_to_node_arg_index;  // 拷貝一份輸出節點map

然後遍歷圖中每個節點以及每個節點的輸入:

for (const auto& node : Nodes()) {
  // Go thru all node's inputs.
  for (const auto* input_arg : node.InputDefs()) {
    ...
  }
}

在輸出節點name列表中查詢當前輸入name

auto output_arg_iter = output_name_to_node_arg_index.find(input_arg->Name());

如果沒有找到,說明這個節點的輸入就是圖的輸入,接下來還需要判斷這個輸入是否已經放在區域性變數added_input_names中:

if (output_name_to_node_arg_index.end() == output_arg_iter) {
  // This input arg is not the output of another node so must come from either a graph input or an initializer.
  const std::string& name = input_arg->Name();
  if (added_input_names.end() == added_input_names.find(name)) {
    ...
  }
}

如果已經放到區域性變數added_input_names中,就可以判斷節點的下一個輸入或者下一個節點的輸入。如果沒有放到區域性變數added_input_names中:

bool is_initializer = name_to_initial_tensor_.find(name) != name_to_initial_tensor_.end();  // 判斷當前input_arg是否已初始化過的tensor,如果是就不可以再放置到 graph_inputs_excluding_initializers_ 中
if (!graph_inputs_manually_set_) {   // 如果未主動呼叫 SetInputs() 方法
  // if IR version < 4 all initializers must have a matching graph input
  // (even though the graph input is not allowed to override the initializer).
  // if IR version >= 4 initializers are not required to have a matching graph input.
  // any graph inputs that are to override initializers must be specified by calling SetInputs.
  if (!is_initializer || ir_version_ < 4) {
    graph_inputs_including_initializers_.push_back(input_arg);
  }
  if (!is_initializer) {
    // If input_arg is not of an initializer, we add it into graph_inputs_excluding_initializers_.
    graph_inputs_excluding_initializers_.push_back(input_arg);
  }
} else {  // 如果主動呼叫了 SetInputs() 方法
  // graph_inputs_including_initializers_ has been manually populated by SetInputs.
  // Validation: the <input_arg> must be in graph inputs or initializers when it's manually set.
  if (!is_initializer) {
    const auto& inputs = graph_inputs_including_initializers_;
    bool in_inputs = std::find(inputs.begin(), inputs.end(), input_arg) != inputs.end();
    if (!in_inputs) {
      return Status(ONNXRUNTIME, FAIL,
                    name + " must be either specified in graph inputs or graph initializers.");
    }
  } else {
    // If arg_input is of an initializer, we remove it from graph_inputs_excluding_initializers_
    // whose initial content has both initializers and non-initializers.
    auto input_pos = std::find(graph_inputs_excluding_initializers_.begin(),
                                graph_inputs_excluding_initializers_.end(),
                                input_arg);
    if (input_pos != graph_inputs_excluding_initializers_.end()) {
      graph_inputs_excluding_initializers_.erase(input_pos);
    }
  }
}
added_input_names.insert(name);

可以看到,這裡會把當前的 input_arg 分別放到 graph_inputs_including_initializers_graph_inputs_excluding_initializers_ 中,並將name放在added_input_names中。

如果該輸入的name已經在輸出節點name列表中,說明這個節點是中間輸出結果,而非整個圖的輸出,因此應該將其從圖的輸出(graph_output_args)中刪除,並放在 value_info_ 中:

if (output_name_to_node_arg_index.end() == output_arg_iter) {
  ...
}else if(graph_output_args.erase(output_arg_iter->first) >= 1){
  value_info_.insert(input_arg);
}

以上我們對Graph的三個成員變數:graph_inputs_including_initializers_graph_inputs_excluding_initializers_value_info_分別進行了賦值,其中前兩者儲存輸入,後者儲存中間結果。我們還需要處理圖的輸出結果:`graph_outputs_`:

if (!graph_outputs_manually_set_) {
  // Set graph outputs in order.
  std::vector<size_t> graph_output_args_index;
  graph_output_args_index.reserve(graph_output_args.size());
  for (const auto& output_arg : graph_output_args) {          // graph_output_args原本儲存了所有節點的輸出,但是前面的程式碼已經把中間節點的輸出給移除了,因此剩下的就是整個Graph的輸出
    graph_output_args_index.push_back(output_arg.second);
  }

  std::sort(graph_output_args_index.begin(), graph_output_args_index.end());
  for (auto& output_arg_index : graph_output_args_index) {
    graph_outputs_.push_back(output_node_args_in_order[output_arg_index]);
  }
}

最後,還需要對 graph_overridable_initializers_ 進行處理:

ComputeOverridableInitializers();

進入這個函式內部:

void Graph::ComputeOverridableInitializers() {
  graph_overridable_initializers_.clear();
  if (CanOverrideInitializer()) {
    // graph_inputs_excluding_initializers_ and graph_inputs_including_initializers_
    // are inserted in the same order. So we walk and compute the difference.
    auto f_incl = graph_inputs_including_initializers_.cbegin();
    const auto l_incl = graph_inputs_including_initializers_.cend();
    auto f_excl = graph_inputs_excluding_initializers_.cbegin();
    const auto l_excl = graph_inputs_excluding_initializers_.cend();

    while (f_incl != l_incl) {
      // Equal means not an initializer
      if (f_excl != l_excl && *f_incl == *f_excl) {
        ++f_incl;
        ++f_excl;
        continue;
      }
      graph_overridable_initializers_.push_back(*f_incl);
      ++f_incl;
    }
  }
}

這是一個很簡單的演算法,通過比較 graph_inputs_including_initializers_graph_inputs_excluding_initializers_,提取出 initializer 並放置到 graph_overridable_initializers_ 中。

至此,我們完成了對 Graph::SetGraphInputsOutputs() 函式的解析。

總結

針對這個函式的解析不僅理解了如何從Graph的nodes中分析出graph的輸入和輸出,而且懂得了graph_overridable_initializers_以及value_info_的作用。

相關文章