現在的深度學習框架一般都是基於 Python 來實現,構建、訓練、儲存和呼叫模型都可以很容易地在 Python 下完成。但有時候,我們在實際應用這些模型的時候可能需要在其他程式語言下進行,本文將通過直接呼叫 TensorFlow 的 C/C++ 介面來匯入 TensorFlow 預訓練好的模型。
1. 環境配置:點此檢視 C/C++ 介面的編譯方法(原文此處為超連結)
2. 匯入預定義的圖和訓練好的引數值
// set up your input paths
const string pathToGraph = "/home/senius/python/c_python/test/model-10.meta";
const string checkpointPath = "/home/senius/python/c_python/test/model-10";
auto session = NewSession(SessionOptions()); // 建立會話
if (session == nullptr)
{
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);  // 匯入圖模型
if (!status.ok())
{
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());  // 將圖模型加入到會話中
if (!status.ok())
{
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath; // 讀取預訓練好的權重
status = session->Run({{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor},}, {},
{graph_def.saver_def().restore_op_name()}, nullptr);
if (!status.ok())
{
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
複製程式碼
3. 準備測試資料
const string filename = "/home/senius/python/c_python/test/04t30t00.npy";
//Read TXT data to array
float Array[1681*41];
ifstream is(filename);
for (int i = 0; i < 1681*41; i++){
is >> Array[i];
}
is.close();
tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 41, 41, 41, 1}));
auto input_tensor_mapped = input_tensor.tensor<float, 5>();
float *pdata = Array;
// copying the data into the corresponding tensor
for (int x = 0; x < 41; ++x)//depth
{
for (int y = 0; y < 41; ++y) {
for (int z = 0; z < 41; ++z) {
const float *source_value = pdata + x * 1681 + y * 41 + z;
input_tensor_mapped(0, x, y, z, 0) = *source_value;
}
}
}
複製程式碼
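- 本例中模型的輸入是一個 [None, 41, 41, 41, 1] 的張量,這裡每次只預測一個樣本,因此建立形狀為 [1, 41, 41, 41, 1] 的張量。我們需要先從文字檔中讀出 41×41×41 個測試資料,然後按 (深度, 行, 列) 的順序正確地填充到張量中去。注意:程式是以空白分隔的純文字方式讀取資料,雖然範例檔名的副檔名是 .npy,但它不能是 numpy.save 產生的二進位格式,Python 端應改用 numpy.savetxt(path, data.reshape(1681, 41)) 之類的方式儲存。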
4. 前向傳播得到預測值
std::vector<tensorflow::Tensor> finalOutput;
std::string InputName = "X"; // Your input placeholder's name
std::string OutputName = "sigmoid"; // Your output tensor's name
vector<std::pair<string, Tensor> > inputs;
inputs.push_back(std::make_pair(InputName, input_tensor));
// Fill input tensor with your input data
session->Run(inputs, {OutputName}, {}, &finalOutput);
auto output_y = finalOutput[0].scalar<float>();
std::cout << output_y() << "\n";
複製程式碼
- 通過給定輸入和輸出張量的名字,我們可以將測試資料傳入到模型中,然後進行前向傳播得到預測值。
5. 一些問題
- 本模型是在 TensorFlow 1.4 下訓練的,用編譯好的 TensorFlow 1.4 C++ 介面可以正常呼叫模型;但若想呼叫更高版本 TensorFlow 訓練出的模型,則會報錯。根據出錯資訊猜測,可能是高版本匯出的圖中用到了低版本尚未註冊的運算(op),所以無法正常執行。
- 若是編譯更高版本的 TensorFlow,比如最新的 TensorFlow 1.11 的 C++ 介面,則無論是呼叫舊版本還是新版本訓練的模型都不能正常執行。出錯資訊如下:Error loading checkpoint from /media/lab/data/yongsen/Tensorflow_test/test/model-40: Invalid argument: Session was not created with a graph before Run()!,網上暫時也查不到解決辦法,姑且先放在這裡。(補充:這個錯誤可能表示 session->Create() 雖然回傳成功,但實際上沒有載入任何節點,例如讀入的 graph_def 為空;可以先印出 graph_def.graph_def().node_size() 確認圖是否被正確讀入,此推測有待驗證。)
6. 完整程式碼
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/io_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/parsing_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/array_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/math_ops.h>
#include </home/senius/tensorflow-r1.4/bazel-genfiles/tensorflow/cc/ops/data_flow_ops.h>
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>

using namespace std;
using namespace tensorflow;
using namespace tensorflow::ops;
int main()
{
// set up your input paths
const string pathToGraph = "/home/senius/python/c_python/test/model-10.meta";
const string checkpointPath = "/home/senius/python/c_python/test/model-10";
auto session = NewSession(SessionOptions());
if (session == nullptr)
{
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok())
{
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok())
{
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run({{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor},}, {},
{graph_def.saver_def().restore_op_name()}, nullptr);
if (!status.ok())
{
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
cout << 1 << endl;
const string filename = "/home/senius/python/c_python/test/04t30t00.npy";
//Read TXT data to array
float Array[1681*41];
ifstream is(filename);
for (int i = 0; i < 1681*41; i++){
is >> Array[i];
}
is.close();
tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, 41, 41, 41, 1}));
auto input_tensor_mapped = input_tensor.tensor<float, 5>();
float *pdata = Array;
// copying the data into the corresponding tensor
for (int x = 0; x < 41; ++x)//depth
{
for (int y = 0; y < 41; ++y) {
for (int z = 0; z < 41; ++z) {
const float *source_value = pdata + x * 1681 + y * 41 + z;
// input_tensor_mapped(0, x, y, z, 0) = *source_value;
input_tensor_mapped(0, x, y, z, 0) = 1;
}
}
}
std::vector<tensorflow::Tensor> finalOutput;
std::string InputName = "X"; // Your input placeholder's name
std::string OutputName = "sigmoid"; // Your output placeholder's name
vector<std::pair<string, Tensor> > inputs;
inputs.push_back(std::make_pair(InputName, input_tensor));
// Fill input tensor with your input data
session->Run(inputs, {OutputName}, {}, &finalOutput);
auto output_y = finalOutput[0].scalar<float>();
std::cout << output_y() << "\n";
return 0;
}
複製程式碼
- Cmakelist 檔案如下
cmake_minimum_required(VERSION 3.8)
project(Tensorflow_test)

# The TensorFlow r1.4 C++ headers need at least C++11; do not silently fall back to an older standard.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Root of the TensorFlow source tree in which libtensorflow_cc.so was built with bazel.
# Override it instead of editing every path, e.g.:
#   cmake -DTENSORFLOW_DIR=/path/to/tensorflow-r1.4 ..
set(TENSORFLOW_DIR /home/senius/tensorflow-r1.4 CACHE PATH "TensorFlow source tree built with bazel")

set(SOURCE_FILES main.cpp)

include_directories(
    ${TENSORFLOW_DIR}
    # bazel creates its bazel-genfiles symlink at the workspace root, not under tensorflow/
    # (main.cpp includes the generated cc/ops headers from ${TENSORFLOW_DIR}/bazel-genfiles)
    ${TENSORFLOW_DIR}/bazel-genfiles
    ${TENSORFLOW_DIR}/tensorflow/contrib/makefile/gen/protobuf/include
    ${TENSORFLOW_DIR}/tensorflow/contrib/makefile/gen/host_obj
    ${TENSORFLOW_DIR}/tensorflow/contrib/makefile/gen/proto
    ${TENSORFLOW_DIR}/tensorflow/contrib/makefile/downloads/nsync/public
    ${TENSORFLOW_DIR}/tensorflow/contrib/makefile/downloads/eigen
    ${TENSORFLOW_DIR}/bazel-out/local_linux-py3-opt/genfiles
)

add_executable(Tensorflow_test ${SOURCE_FILES})

target_link_libraries(Tensorflow_test
    ${TENSORFLOW_DIR}/bazel-bin/tensorflow/libtensorflow_cc.so
    ${TENSORFLOW_DIR}/bazel-bin/tensorflow/libtensorflow_framework.so
)
複製程式碼
獲取更多精彩,請關注「seniusen」!