tensorflow:一個簡單的python訓練儲存模型,java還原模型方法
總結一下這段時間學習使用tensorflow的一些經驗。主要應用場景是,使用python語言訓練一個簡單的LR模型,並且將模型以savedModel格式儲存模型,然後以python和java語言還原模型,預測結果。
python tensorflow版本和java client版本要對應,通過 tf.__version__ 檢視python的TensorFlow版本,java client版本就是jar包版本。
(1)訓練模型
import tensorflow as tf
import numpy as np
#生成訓練資料
x = np.ndarray(dtype=np.float32, shape=[4, 2])
x[0] = [1,1]
x[1] = [1,2]
x[2] = [1,3]
x[3] = [2,4]
print('====================')
print(x)
print(x.shape)
print(x.dtype)
#建立placeHolder作為輸入
x_inputs = tf.placeholder(tf.float32, shape=[None, 2])
#輸出結果
y_true = tf.constant([[2], [4], [5], [9]], dtype=tf.float32)
#單層神經網路,搭建LR模型
linear_model = tf.layers.Dense(units=1)
y_pred = linear_model(x_inputs)
#構建session
sess = tf.Session()
#儲存模型tensorbord視覺化結構的writer
writer = tf.summary.FileWriter("/Users/yourName/pythonworkspace/tmp/log", sess.graph)
#初始化變數
init = tf.global_variables_initializer()
sess.run(init)
#構建損失函式
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)
#梯度下降優化器
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
#開始訓練模型
print('================start=================')
for i in range(10000):
_, loss_value = sess.run((train, loss), feed_dict={x_inputs:x})
if i % 1000 == 0:
print(loss_value)
#關閉視覺化writer,可以通過tensorboard --logdir /Users/yourName/pythonworkspace/tmp/log載入視覺化模型
writer.close()
#構建savedModel構建器
builder = tf.saved_model.builder.SavedModelBuilder("/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel")
# x 為輸入tensor, keep_prob為dropout的prob tensor
inputs = {'input': tf.saved_model.utils.build_tensor_info(x_inputs)}
# y 為最終需要的輸出結果tensor
outputs = {'output': tf.saved_model.utils.build_tensor_info(y_pred)}
signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')
#儲存模型
builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature':signature})
builder.save()
(2)python 載入模型
import tensorflow as tf
with tf.Session(graph=tf.Graph()) as sess:
#載入模型
meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], "/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel")
#載入模型簽名
signature = meta_graph_def.signature_def
print(signature)
#從簽名中獲得張量名
y_tensor_name = signature['test_signature'].outputs['output'].name
x_tensor_name = signature['test_signature'].inputs['input'].name
print(y_tensor_name)
print(x_tensor_name)
#還原張量
y_pred = sess.graph.get_tensor_by_name(y_tensor_name)
x_inputs = sess.graph.get_tensor_by_name(x_tensor_name)
# 預測結果
print(sess.run(y_pred, feed_dict={x_inputs:[[1,6]]}))
(3)java 載入模型
載入tensorflow依賴包
<dependencies>
<!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.8.0-rc0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>proto</artifactId>
<version>1.8.0-rc1</version>
</dependency>
</dependencies>
載入模型程式碼
import com.google.protobuf.InvalidProtocolBufferException;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import java.util.List;
public class Test {
public static void main(String[] args) throws InvalidProtocolBufferException {
/*載入模型 */
SavedModelBundle savedModelBundle = SavedModelBundle.load("/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel", "test_saved_model");
/*構建預測張量*/
float[][] matrix = new float[1][2];
matrix[0][0] = 1;
matrix[0][1] = 6;
Tensor<Float> x = Tensor.create(matrix, Float.class);
/*獲取模型簽名*/
SignatureDef sig = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef()).getSignatureDefOrThrow("test_signature");
String inputName = sig.getInputsMap().get("input").getName();
System.out.println(inputName);
String outputName = sig.getOutputsMap().get("output").getName();
System.out.println(outputName);
/*預測模型結果*/
List<Tensor<?>> y = savedModelBundle.session().runner().feed(inputName, x).fetch(outputName).run();
float [][] result = new float[1][1];
System.out.println(y.get(0).dataType());
System.out.println(y.get(0).copyTo(result));
System.out.println(result[0][0]);
}
}
相關文章
- TensorFlow模型儲存和提取方法模型
- 訓練模型的儲存與載入模型
- TensorFlow 呼叫預訓練好的模型—— Python 實現模型Python
- 用 Java 訓練深度學習模型,原來可以這麼簡單!Java深度學習模型
- pytorch-模型儲存與載入自己訓練的模型詳解PyTorch模型
- 如何將keras訓練的模型轉換成tensorflow lite模型Keras模型
- Tensorflow SavedModel模型的儲存與載入模型
- TensorFlow2.0教程-使用keras訓練模型Keras模型
- 機器學習-訓練模型的儲存與恢復(sklearn)機器學習模型
- Tensorflow模型的儲存與恢復載入模型
- 如何藉助分散式儲存 JuiceFS 加速 AI 模型訓練分散式UIAI模型
- 記錄:tf.saved_model 模組的簡單使用(TensorFlow 模型儲存與恢復)模型
- tensorflow模型持久化儲存和載入模型持久化
- 一文詳解TensorFlow模型遷移及模型訓練實操步驟模型
- 使用PaddleFluid和TensorFlow訓練序列標註模型UI模型
- 訓練一個目標檢測模型模型
- 昇騰遷移丨4個TensorFlow模型訓練案例解讀模型
- 怎麼訓練出一個NB的Prophet模型模型
- TensorFlow 載入多個模型的方法模型
- PreSTU:一個專門為場景文字理解而設計的簡單預訓練模型REST模型
- PyTorch儲存模型斷點以及載入斷點繼續訓練PyTorch模型斷點
- 獲取和生成基於TensorFlow的MobilNet預訓練模型模型
- skmultiflow使用自己的csv檔案訓練模型並儲存實驗結果模型
- 「NLP」GPT:第一個引入Transformer的預訓練模型GPTORM模型
- keras轉tensorflow lite【方法二】直接轉:簡單模型例項Keras模型
- 海南話語音識別模型——模型訓練(一)模型
- 【LLM訓練】從零訓練一個大模型有哪幾個核心步驟?大模型
- 利用Python訓練手勢模型程式碼Python模型
- 在 C/C++ 中使用 TensorFlow 預訓練好的模型—— 間接呼叫 Python 實現C++模型Python
- 預訓練語言模型:還能走多遠?模型
- 常見預訓練語言模型簡述模型
- 1890美元,就能從頭訓練一個還不錯的12億引數擴散模型模型
- Python 載入 TensorFlow 模型Python模型
- MxNet預訓練模型到Pytorch模型的轉換模型PyTorch
- Question | 標註下資料、訓練個模型,商用的智慧鑑黃有這麼簡單嗎?模型
- 監控大模型訓練大模型
- PyTorch預訓練Bert模型PyTorch模型
- fasttext訓練模型程式碼AST模型