tensorflow機器學習模型的跨平臺上線

劉建平Pinard發表於2018-07-01

    在用PMML實現機器學習模型的跨平臺上線中,我們討論了使用PMML檔案來實現跨平臺模型上線的方法,這個方法當然也適用於tensorflow生成的模型,但是由於tensorflow模型往往較大,使用無法優化的PMML檔案大多數時候很笨拙,因此本文我們專門討論下tensorflow機器學習模型的跨平臺上線的方法。

1. tensorflow模型的跨平臺上線的備選方案

    tensorflow模型的跨平臺上線的備選方案一般有三種:即PMML方式,tensorflow serving方式,以及跨語言API方式。

    PMML方式的主要思路在上一篇以及講過。這裡唯一的區別是轉化生成PMML檔案需要用一個Java庫jpmml-tensorflow來完成,生成PMML檔案後,跨語言載入模型和其他PMML模型檔案基本類似。

    tensorflow serving是tensorflow 官方推薦的模型上線預測方式,它需要一個專門的tensorflow伺服器,用來提供預測的API服務。如果你的模型和對應的應用是比較大規模的,那麼使用tensorflow serving是比較好的使用方式。但是它也有一個缺點,就是比較笨重,如果你要使用tensorflow serving,那麼需要自己搭建serving叢集並維護這個叢集。所以為了一個小的應用去做這個工作,有時候會覺得麻煩。

    跨語言API方式是本文要討論的方式,它會用tensorflow自己的Python API生成模型檔案,然後用tensorflow的客戶端庫比如Java或C++庫來做模型的線上預測。下面我們會給一個生成生成模型檔案並用tensorflow Java API來做線上預測的例子。

2. 訓練模型並生成模型檔案

    我們這裡給一個簡單的邏輯迴歸並生成邏輯迴歸tensorflow模型檔案的例子。

    完整程式碼參見我的github:https://github.com/ljpzzz/machinelearning/blob/master/model-in-product/tensorflow-java

    首先,我們生成了一個6特徵,3分類輸出的4000個樣本資料。

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.datasets.samples_generator import make_classification
import tensorflow as tf
X1, y1 = make_classification(n_samples=4000, n_features=6, n_redundant=0,
                             n_clusters_per_class=1, n_classes=3)

    接著我們構建tensorflow的資料流圖,這裡要注意裡面的兩個名字,第一個是輸入x的名字input,第二個是輸出prediction_labels的名字output,這裡的這兩個名字可以自己取,但是後面會用到,所以要保持一致。

learning_rate = 0.01
training_epochs = 600
batch_size = 100

x = tf.placeholder(tf.float32, [None, 6],name='input') # 6 features
y = tf.placeholder(tf.float32, [None, 3]) # 3 classes

W = tf.Variable(tf.zeros([6, 3]))
b = tf.Variable(tf.zeros([3]))

# softmax迴歸
pred = tf.nn.softmax(tf.matmul(x, W) + b, name="softmax") 
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

prediction_labels = tf.argmax(pred, axis=1, name="output")

init = tf.global_variables_initializer()

    接著就是訓練模型了,程式碼比較簡單,畢竟只是一個演示:

sess = tf.Session()
sess.run(init)
y2 = tf.one_hot(y1, 3)
y2 = sess.run(y2)

for epoch in range(training_epochs):

    _, c = sess.run([optimizer, cost], feed_dict={x: X1, y: y2})
    if (epoch+1) % 10 == 0:
        print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c))
    
print ("優化完畢!")
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y2, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = sess.run(accuracy, feed_dict={x: X1, y: y2})
print (acc)

    列印輸出我這裡就不寫了,大家可以自己去試一試。接著就是關鍵的一步,存模型檔案了,注意要用convert_variables_to_constants這個API來儲存模型,否則模型引數不會隨著模型圖一起存下來。

graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
tf.train.write_graph(graph, '.', 'rf.pb', as_text=False)

    至此,我們的模型檔案rf.pb已經被儲存下來了,下面就是要跨平臺上線了。 

3. 模型檔案在Java平臺上線

    這裡我們以Java平臺的模型上線為例,C++的API上線我沒有用過,這裡就不寫了。我們需要引入tensorflow的java庫到我們工程的maven或者gradle檔案。這裡給出maven的依賴如下,版本可以根據實際情況選擇一個較新的版本。

        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.7.0</version>
        </dependency>

    接著就是程式碼了,這個程式碼會比JPMML的要簡單,我給出了4個測試樣本的預測例子如下,一定要注意的是裡面的input和output要和訓練模型的時候對應的節點名字一致。

import org.tensorflow.*;
import org.tensorflow.Graph;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;


/**
 * Created by 劉建平pinard on 2018/7/1.
 */
public class TFjavaDemo {
    public static void main(String args[]){
        byte[] graphDef = loadTensorflowModel("D:/rf.pb");
        float inputs[][] = new float[4][6];
        for(int i = 0; i< 4; i++){
            for(int j =0; j< 6;j++){
                if(i<2) {
                    inputs[i][j] = 2 * i - 5 * j - 6;
                }
                else{
                    inputs[i][j] = 2 * i + 5 * j - 6;
                }
            }
        }
        Tensor<Float> input = covertArrayToTensor(inputs);
        Graph g = new Graph();
        g.importGraphDef(graphDef);
        Session s = new Session(g);
        Tensor result = s.runner().feed("input", input).fetch("output").run().get(0);

        long[] rshape = result.shape();
        int rs = (int) rshape[0];
        long realResult[] = new long[rs];
        result.copyTo(realResult);

        for(long a: realResult ) {
            System.out.println(a);
        }
    }
    static private byte[] loadTensorflowModel(String path){
        try {
            return Files.readAllBytes(Paths.get(path));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    static private Tensor<Float> covertArrayToTensor(float inputs[][]){
        return Tensors.create(inputs);
    }
}

    我的預測輸出是1,1,0,0,供大家參考。

4. 一點小結

    對於tensorflow來說,模型上線一般選擇tensorflow serving或者client API庫來上線,前者適合於較大的模型和應用場景,後者則適合中小型的模型和應用場景。因此演算法工程師使用在產品之前需要做好選擇和評估。

 

(歡迎轉載,轉載請註明出處。歡迎溝通交流: liujianping-ok@163.com) 

相關文章