史上最簡潔使用Tensorflow_model_server

JimmyLeung1874發表於2018-11-14

最簡潔使用Tensorflow_model_server

Tensorflow_model_server的目的是:

統一管理一個模型伺服器,利於讓他人使用這個模型,而且可以動態更新模型,模型也會常住在記憶體裡面,加快結果輸出,減少模型載入時間。

事先準備操作:

Tensorflow_model_server 安裝:


參考:https://www.tensorflow.org/serving/setup

這裡會有一個坑: 直接pip intall tensorflow-model-server 的時候它會顯示已經安裝完,但是實際上還是沒有找到這個庫,處理方法是,先把原來的 tensorflow-model-server給uninstall 再按安裝庫上面安裝即可。


啟動 tensorflow_model_server :


tensorflow_model_server

--port=埠號 8000

--model_name=模型名稱 例:256

--model_base_path=絕對路徑 例: /notebooks/animieGan/TwinGAN/export

正常情況下:
正常圖片
異常情況:
它回報找不到模型,檢查下路徑是否是絕對路徑

這樣就啟動完了 tensorflow_model_server

編寫客戶端:

引入頭文檔案:
如果沒找到,就通過pip安裝:

import numpy as np
import scipy.misc
import tensorflow as tf

from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2  # pip install tensorflow-serving-api
from tensorflow_serving.apis import prediction_service_pb2
from PIL import Image
import sys

tf.app.flags.DEFINE_string('server', '10.11.32.51:8000',
                           'twingan_server host:port')
tf.flags.DEFINE_integer('gpu', -1,
                        'GPU ID (negative value indicates CPU)')
FLAGS = tf.app.flags.FLAGS
output_dir = './'

def main(_):
	image_hw = 256
	raw_image = Image.open('./78ae254bdcc94ffc5e8a921923a8d266.jpg')
	image_resized =  raw_image.resize((image_hw, image_hw), Image.BICUBIC)

	input_image =   np.expand_dims(image_resized / np.float32(255.0), 0)


	#input_image = np.reshape(int_image, 927696).astype(np.float32)
	host, port = FLAGS.server.split(':')
	print('GPU: {}'.format(FLAGS.gpu))
	channel = implementations.insecure_channel(host, int(port))
	
	stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
	request = predict_pb2.PredictRequest()
	request.model_spec.name = "256"
	request.model_spec.signature_name = "serving_default"
	request.inputs['inputs'].CopyFrom(tf.contrib.util.make_tensor_proto(input_image))

	result_future = stub.Predict.future(request, 15.0)  # 5 seconds
	exception = result_future.exception()
	if exception:
		print(exception)
	else:
		result_future.add_done_callback(doneJob(result_future))
	return 'hello wrod'

def doneJob(result_future):
	print("finish")
	sys.stdout.write('.')
	sys.stdout.flush()
        # TODO: do post-processing using another function.
	response_images = np.reshape(np.array(result_future.result().outputs['outputs'].float_val),[dim.size for dim in result_future.result().outputs['outputs'].tensor_shape.dim]) * 255.0
	scipy.misc.imsave('outfile2.jpg', response_images[0])


if __name__ == '__main__':
	tf.app.run()

	

核心是:

  • 1.構造一個 request = predict_pb2.PredictRequest()
  • 2.填充這個request的必要引數 比如image,資料
  • 3.通過 prediction_service_pb2 來處理 2中的 request
  • 4.接受流,生成 responedoneJob
  • 5.通過 tf.app.run() 啟動程式

測試模型與啟動程式碼:

模型地址https://drive.google.com/file/d/1C1tadCQjzsiW2GBeL8BbgfyKXKsoQCjJ/view
啟動程式碼 python model_server.py --twingan_server=10.11.32.51:8000 --image_hw=256 --gpu=1

相關文章