tensorflow:一個簡單的python訓練儲存模型,java還原模型方法

witsmakemen發表於2018-04-24

總結一下這段時間學習使用 tensorflow 的一些經驗。主要應用場景是:使用 python 語言訓練一個簡單的 LR 模型,並將模型以 SavedModel 格式儲存,然後分別以 python 和 java 語言還原模型並預測結果。

python tensorflow版本和java client版本要對應,通過 tf.__version__ 檢視python的TensorFlow版本,java client版本就是jar包版本。

(1)訓練模型

import tensorflow as tf
import numpy as np

# Train a one-layer linear model on four hand-written samples, log the graph
# for TensorBoard, and export the trained model in SavedModel format.

# Training data: four samples with two features each.
x = np.array([[1, 1], [1, 2], [1, 3], [2, 4]], dtype=np.float32)
print('====================')
print(x)
print(x.shape)
print(x.dtype)

# Placeholder used as the model input; the batch dimension is left open.
x_inputs = tf.placeholder(tf.float32, shape=[None, 2])

# Expected output for each of the four training samples.
y_true = tf.constant([[2], [4], [5], [9]], dtype=tf.float32)

# Single dense layer with one unit, i.e. a linear model.
linear_model = tf.layers.Dense(units=1)
y_pred = linear_model(x_inputs)

# Mean squared error between the targets and the predictions.
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)

# Plain gradient descent optimizer with a learning rate of 0.01.
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

# Build the initializer only after every op has been added to the graph, so
# that it covers any variable the optimizer might create.
init = tf.global_variables_initializer()

# The context manager closes the session even if training or export fails.
with tf.Session() as sess:
    # The writer serializes the graph when it is constructed, so create it
    # once the graph is complete; otherwise the loss and training ops would
    # be missing from the TensorBoard view.
    writer = tf.summary.FileWriter("/Users/yourName/pythonworkspace/tmp/log", sess.graph)
    sess.run(init)

    # Training loop: report the loss every 1000 steps.
    print('================start=================')
    for i in range(10000):
        _, loss_value = sess.run((train, loss), feed_dict={x_inputs: x})
        if i % 1000 == 0:
            print(loss_value)

    # Close the writer; the graph can then be inspected with
    # tensorboard --logdir /Users/yourName/pythonworkspace/tmp/log
    writer.close()

    # SavedModel builder. NOTE: the builder refuses to overwrite an existing
    # export directory, so remove a previous export before re-running.
    builder = tf.saved_model.builder.SavedModelBuilder("/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel")

    # Signature input: the feature placeholder, exposed under the key 'input'.
    inputs = {'input': tf.saved_model.utils.build_tensor_info(x_inputs)}

    # Signature output: the prediction tensor, exposed under the key 'output'.
    outputs = {'output': tf.saved_model.utils.build_tensor_info(y_pred)}

    signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')

    # Store the graph and variables under the tag 'test_saved_model' with the
    # signature registered as 'test_signature'; loaders must use these names.
    builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature': signature})
    builder.save()

(2)python 載入模型

import tensorflow as tf
# Restore the exported SavedModel into a fresh graph and run one prediction.
with tf.Session(graph=tf.Graph()) as sess:
    # Load the meta graph that was saved under the 'test_saved_model' tag.
    loaded_meta_graph = tf.saved_model.loader.load(
        sess,
        ['test_saved_model'],
        "/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel",
    )

    # All signatures stored with the model.
    signature_map = loaded_meta_graph.signature_def
    print(signature_map)

    # Look up the tensor names recorded in the 'test_signature' signature.
    test_signature = signature_map['test_signature']
    output_name = test_signature.outputs['output'].name
    input_name = test_signature.inputs['input'].name
    print(output_name)
    print(input_name)

    # Resolve the names back to tensors of the restored graph.
    restored_graph = sess.graph
    prediction_tensor = restored_graph.get_tensor_by_name(output_name)
    input_tensor = restored_graph.get_tensor_by_name(input_name)

    # Predict for a single sample.
    sample = [[1, 6]]
    print(sess.run(prediction_tensor, feed_dict={input_tensor: sample}))

(3)java 載入模型
載入tensorflow依賴包

 <dependencies>
        <!-- TensorFlow Java client and its protobuf classes.
             The client version should correspond to the Python TensorFlow
             version (tf.__version__) that exported the model.
             NOTE(review): the two artifacts below use different release
             candidates (rc0 vs rc1); confirm this is intended. -->
        <!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.8.0-rc0</version>
        </dependency>
        <!-- Provides org.tensorflow.framework.MetaGraphDef / SignatureDef. -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>proto</artifactId>
            <version>1.8.0-rc1</version>
        </dependency>
    </dependencies>

載入模型程式碼

import com.google.protobuf.InvalidProtocolBufferException;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;

import java.util.List;


/**
 * Loads the SavedModel exported by the Python training script and runs a
 * single prediction through it.
 */
public class Test {
    /**
     * Restores the model, feeds the sample {@code [1, 6]} and prints the result.
     *
     * @param args unused
     * @throws InvalidProtocolBufferException if the serialized MetaGraphDef
     *         stored in the bundle cannot be parsed
     */
    public static void main(String[] args) throws InvalidProtocolBufferException {

        /* Input sample: one row with two features. */
        float[][] matrix = new float[1][2];
        matrix[0][0] = 1;
        matrix[0][1] = 6;

        /* The bundle and the tensors own native resources, so they are
           released with try-with-resources / finally instead of being leaked. */
        try (SavedModelBundle savedModelBundle = SavedModelBundle.load("/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel", "test_saved_model");
             Tensor<Float> x = Tensor.create(matrix, Float.class)) {

            /* Read the tensor names from the signature saved as "test_signature". */
            SignatureDef sig = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef()).getSignatureDefOrThrow("test_signature");
            /* The OrThrow accessors report a missing key explicitly instead of
               failing later with a NullPointerException. */
            String inputName = sig.getInputsOrThrow("input").getName();
            System.out.println(inputName);
            String outputName = sig.getOutputsOrThrow("output").getName();
            System.out.println(outputName);

            /* Run the prediction. */
            List<Tensor<?>> y = savedModelBundle.session().runner().feed(inputName, x).fetch(outputName).run();
            try {
                float[][] result = new float[1][1];
                System.out.println(y.get(0).dataType());
                /* copyTo returns the destination array, so this line prints the
                   array reference; the predicted value is printed below. */
                System.out.println(y.get(0).copyTo(result));
                System.out.println(result[0][0]);
            } finally {
                /* Fetched tensors must be closed by the caller. */
                for (Tensor<?> fetched : y) {
                    fetched.close();
                }
            }
        }

    }
}

相關文章