在 C/C++ 中使用 TensorFlow 預訓練好的模型—— 間接呼叫 Python 實現

seniusen發表於2019-03-03

現在的深度學習框架一般都是基於 Python 來實現,構建、訓練、儲存和呼叫模型都可以很容易地在 Python 下完成。但有時候,我們在實際應用這些模型的時候可能需要在其他程式語言下進行,本文將通過 C/C++ 間接呼叫 Python 的方式來實現在 C/C++ 程式中呼叫 TensorFlow 預訓練好的模型。

1. 環境配置

  • 為了能在 C/C++ 中呼叫 Python,我們需要配置一下標頭檔案和庫的路徑,本文以 Code::Blocks 為例介紹。

  • 在 Build -> Project options 新增連結庫 libpython3.5m.so 和標頭檔案 Python.h 所在的路徑,不同 Python 版本可以自己根據情況調整。

在這裡插入圖片描述

在這裡插入圖片描述

2. 初始化並匯入 Python 模組及相關函式

void Initialize()
{
    Py_Initialize();
    if ( !Py_IsInitialized() )
    {
        printf("Initialize failed!");
    }

	// Path of the python file. 需要更改為 python 檔案所在路徑
    PyRun_SimpleString("import sys");
    PyRun_SimpleString("sys.path.append('/home/senius/python/c_python/test/')");

    const char* modulName = "forward";    // Module name of python file.
    pMod = PyImport_ImportModule(modulName);
    if(!pMod)
    {
        printf("Import Module failed!\n");
    }

    const char* funcName = "load_model";  // Function name in the  python file.
    load_model = PyObject_GetAttrString(pMod, funcName);
    if(!load_model)
    {
        printf("Import load_model Function failed!\n");
    }

    funcName = "predict";  // Function name in the python file.
    predict = PyObject_GetAttrString(pMod, funcName);
    if(!predict)
    {
        printf("Import predict Function failed!\n");
    }

    PyEval_CallObject(load_model, NULL); // 匯入預訓練的模型
    pParm = PyTuple_New(1); // 新建一個元組,引數只能通過元組傳入 Python 程式
}

複製程式碼
  • 通過 PyImport_ImportModule 我們可以匯入需要呼叫的 Python 檔案,然後再通過 PyObject_GetAttrString 得到模組裡面的函式,最後匯入預訓練的模型並新建一個元組作為引數的傳入。

3. 構建從 C 傳入 Python 的引數

void Read_data()
{
    const char* txtdata_path = "/home/senius/python/c_python/test/04t30t00.npy";
    //Path of the TXT file. 需要更改為txt檔案所在路徑

    FILE *fp;
    fp = fopen(txtdata_path, "rb");
    if(fp == NULL)
    {
        printf("Unable to open the file!");
    }
    fread(data, num*SIZE, sizeof(float), fp);
    fclose(fp);

    // copying the data to the list
    int j = 0;
    pArgs = PyList_New(num * SIZE); // 新建一個列表,並填入資料
    while(j < num * SIZE)
    {
        PyList_SET_ITEM(pArgs, j, Py_BuildValue("f", data[j]));
        j++;
    }
}

複製程式碼
  • 讀入測試資料,並將資料填入到一個列表。

4. 將列表傳入元組,然後作為引數傳入 Python 中,並解析返回值

void Test()
{
    PyTuple_SetItem(pParm, 0, pArgs);
    pRetVal = PyEval_CallObject(predict, pParm);

    int list_len = PyList_Size(pRetVal);
    PyObject *list_item = NULL;
    PyObject *tuple_item = NULL;
    for (int i = 0; i < list_len; i++)
    {
        list_item = PyList_GetItem(pRetVal, i);
        tuple_item =  PyList_AsTuple(list_item);
        PyArg_ParseTuple(tuple_item, "f", &iRetVal[i]);
    }
}
複製程式碼
  • 傳入元組到 Python 程式,呼叫 predict 函式得到返回值,然後進行解析。

5. 一些引數和主函式

#include <Python.h>
#include <stdio.h>

#define SIZE 41*41*41*3
#define NUM 100

PyObject* pMod = NULL;
PyObject* load_model = NULL;
PyObject* predict = NULL;
PyObject* pParm = NULL;
PyObject* pArgs = NULL;
PyObject* pRetVal = NULL;

float iRetVal[NUM*3] = {0};
float data[NUM * SIZE] = {0};
int num = 1;  //實際的樣本數100

void Initialize(); 
void Read_data(); 
void Test(); 

int main(int argc, char **argv)
{
    Initialize(); // 初始化
    Read_data(); // 讀入資料
    Test(); // 呼叫預測函式並解析返回值
    
    int j = 0;
    while(j < num*3)
    {
        printf("%f\n", iRetVal[j]);
        j++;
    }
    printf("Done!\n");
    Py_Finalize();

    return 0;
}
複製程式碼

獲取更多精彩,請關注「seniusen」!

seniusen

相關文章