How to convert a Keras-trained model to a TensorFlow Lite model

Posted by colawarrior on 2018-08-21

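Background

Keras is a high-level neural network API that is well suited to beginners; it can run on top of TensorFlow, CNTK, or Theano as its backend. A model trained with Keras is saved as an .h5 file, but running the model on a mobile device requires a .tflite model file.

Implementation

Below is a conversion script I found on GitHub that turns the .h5 file into a frozen TensorFlow .pb file. It needs one small modification: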

# coding: utf-8

# In[ ]:

'''
Input arguments:
num_outputs: this value has nothing to do with the number of classes, batch_size, etc.,
and it is usually equal to 1. If the network is a **multi-stream network**
(forked network with multiple outputs), set the value to the number of outputs.
quantize: if set to True, use the quantize feature of Tensorflow
(https://www.tensorflow.org/performance/quantization) [default: False]
theano_backend: Theano and Tensorflow implement convolution in different ways.
When using Keras with Theano backend, the order is set to 'channels_first'.
This feature is not fully tested, and doesn't work with quantization [default: False]
input_fld: directory holding the keras weights file [default: .]
output_fld: destination directory to save the tensorflow files [default: same as input_fld]
input_model_file: name of the input weight file [default: 'model.h5']
output_model_file: name of the output weight file [default: args.input_model_file + '.pb']
graph_def: if set to True, will write the graph definition as an ascii file [default: False]
output_graphdef_file: if graph_def is set to True, the file name of the
graph definition [default: model.ascii]
output_node_prefix: the prefix to use for output nodes. [default: output_node]
'''


# Parse input arguments

# In[ ]:

import argparse


def str2bool(value):
    # argparse's type=bool treats any non-empty string (even 'False') as True,
    # so parse the text explicitly instead.
    return str(value).lower() in ('true', '1', 'yes', 'y')


parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store",
                    dest='input_fld', type=str, default='.')
parser.add_argument('-output_fld', action="store",
                    dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store",
                    dest='input_model_file', type=str, default='model.h5')
parser.add_argument('-output_model_file', action="store",
                    dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store",
                    dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store",
                    dest='num_outputs', type=int, default=1)
parser.add_argument('-graph_def', action="store",
                    dest='graph_def', type=str2bool, default=False)
parser.add_argument('-output_node_prefix', action="store",
                    dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-quantize', action="store",
                    dest='quantize', type=str2bool, default=False)
parser.add_argument('-theano_backend', action="store",
                    dest='theano_backend', type=str2bool, default=False)
parser.add_argument('-f')
args = parser.parse_args()
parser.print_help()
print('input args: ', args)

if args.theano_backend is True and args.quantize is True:
    raise ValueError("Quantize feature does not work with theano backend.")


# initialize

# In[ ]:

from keras.models import load_model
import tensorflow as tf
from pathlib import Path
from keras import backend as K
from keras.applications import mobilenet
from keras.utils.generic_utils import CustomObjectScope

output_fld =  args.input_fld if args.output_fld == '' else args.output_fld
if args.output_model_file == '':
    args.output_model_file = str(Path(args.input_model_file).name) + '.pb'
Path(output_fld).mkdir(parents=True, exist_ok=True)
weight_file_path = str(Path(args.input_fld) / args.input_model_file)

# Load keras model and rename output

# In[ ]:

K.set_learning_phase(0)
if args.theano_backend:
    K.set_image_data_format('channels_first')
else:
    K.set_image_data_format('channels_last')

# try:
# The main modification is here: load_model must be wrapped in this scope,
# otherwise loading fails with an error
with CustomObjectScope({'relu6': mobilenet.relu6, 'DepthwiseConv2D': mobilenet.DepthwiseConv2D}):
    net_model = load_model(weight_file_path)
# except ValueError as err:
#     print('''Input file specified ({}) only holds the weights, and not the model definition.
#     Save the model using model.save(filename.h5) which will contain the network architecture
#     as well as its weights.
#     If the model is saved using model.save_weights(filename.h5), the model architecture is
#     expected to be saved separately in a json format and loaded prior to loading the weights.
#     Check the keras documentation for more details (https://keras.io/getting-started/faq/)'''
#           .format(weight_file_path))
#     raise err
num_output = args.num_outputs
pred = [None]*num_output
pred_node_names = [None]*num_output
for i in range(num_output):
    pred_node_names[i] = args.output_node_prefix+str(i)
    pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i])
print('output nodes names are: ', pred_node_names)


# [optional] write graph definition in ascii

# In[ ]:

sess = K.get_session()

if args.graph_def:
    f = args.output_graphdef_file
    tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
    print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))


# convert variables to constants and save

# In[ ]:

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
if args.quantize:
    from tensorflow.tools.graph_transforms import TransformGraph
    transforms = ["quantize_weights", "quantize_nodes"]
    transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
    constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
else:
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the frozen graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))

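Once the Keras-to-TensorFlow conversion is done, the next step is to convert the .pb file into a .tflite file. I went through a lot of material for this part, so here are the pitfalls I ran into.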

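  1. If you are on TensorFlow 1.8, first upgrade to 1.9 or downgrade to 1.7, because the toco command in 1.8 does not work properly. To upgrade, run pip3 install -U tensorflow or pip3 install --upgrade tensorflow (the two are equivalent).
  2. After upgrading you can use the toco command. Note: if you installed the whole environment with virtualenv, first activate it with source ./bin/activate; toco is only available inside the activated environment: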
toco --graph_def_file mobilenet_v1_1.0_224_frozen.pb \
  --output_format=TFLITE \
  --output_file=mobilenet_v1_1.0_224_test.tflite \
  --inference_type=FLOAT \
  --input_arrays=input \
  --output_arrays=MobilenetV1/Predictions/Reshape_1 \
  --input_shapes=1,224,224,3

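Be careful not to copy the command from the tutorial verbatim, because it has a couple of pitfalls:

    1. In 1.9, the toco command uses the --graph_def_file parameter in place of --input_file
    2. In 1.9, the toco command has dropped the --input_type parameter

So the command that finally ran successfully is the one shown above.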
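If you convert several models, it can be handy to assemble that command from Python instead of editing it by hand. The following is only a minimal sketch: build_toco_command is a hypothetical helper name, it emits only the flags that appear in the working command above, and it assumes a model with a single input array.

```python
import shlex


def build_toco_command(graph_def_file, output_file, input_array,
                       output_array, input_shape, inference_type="FLOAT"):
    """Return the toco argument list for a frozen .pb -> .tflite conversion.

    Hypothetical helper: it only emits the flags used in the article's command
    and assumes one input array with one shape.
    """
    shape = ",".join(str(dim) for dim in input_shape)
    return [
        "toco",
        "--graph_def_file=" + graph_def_file,
        "--output_format=TFLITE",
        "--output_file=" + output_file,
        "--inference_type=" + inference_type,
        "--input_arrays=" + input_array,
        "--output_arrays=" + output_array,
        "--input_shapes=" + shape,
    ]


cmd = build_toco_command(
    graph_def_file="mobilenet_v1_1.0_224_frozen.pb",
    output_file="mobilenet_v1_1.0_224_test.tflite",
    input_array="input",
    output_array="MobilenetV1/Predictions/Reshape_1",
    input_shape=[1, 224, 224, 3],
)

# Print a shell-safe version of the command for inspection.
print(" ".join(shlex.quote(arg) for arg in cmd))

# To actually run it (inside the activated environment), something like
# subprocess.run(cmd, check=True) should work; it is left out here so the
# sketch has no side effects.
```

Returning a list rather than a single string keeps each flag as its own argument, so file names containing spaces do not need manual quoting when the list is passed to subprocess.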

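  3. Once the command above runs successfully, you can convert your own .pb file to a .tflite file: replace the .pb file name after graph_def_file and the output file name after output_file. The key point is knowing the names of your trained model's input layer and output layer. The easiest way to find them is to load your .pb model with a load_graph-style helper in TensorFlow and walk the graph until you find the corresponding layer names. Then use the input layer name as the value of input_arrays and the output layer name as the value of output_arrays. For a .pb produced by the script above, the output names are the ones the script prints (output_node0 by default).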
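Walking the graph to find those names can be sketched roughly as follows. This is an assumption-laden sketch rather than an official recipe: load_graph_nodes and guess_io_names are hypothetical helper names, the loader assumes the TensorFlow 1.x API (tf.GraphDef and tf.gfile.GFile; under TensorFlow 2.x the equivalents live under tf.compat.v1 and tf.io.gfile), and the heuristic simply treats Placeholder ops as inputs and nodes that nothing else consumes as outputs.

```python
def load_graph_nodes(pb_path):
    """Read a frozen .pb file and return (name, op, inputs) for every node.

    Assumes the TensorFlow 1.x API. TensorFlow is imported lazily so that
    guess_io_names below can be tried without it.
    """
    import tensorflow as tf

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return [(node.name, node.op, list(node.input)) for node in graph_def.node]


def guess_io_names(nodes):
    """Guess input and output node names from (name, op, inputs) tuples.

    Heuristic only: inputs are Placeholder ops, outputs are nodes that no
    other node consumes (ignoring Placeholder, Const and NoOp nodes).
    """
    consumed = set()
    for _, _, inputs in nodes:
        for inp in inputs:
            # Strip the control-dependency marker '^' and the ':N' output index.
            consumed.add(inp.lstrip("^").split(":")[0])
    input_names = [name for name, op, _ in nodes if op == "Placeholder"]
    output_names = [
        name for name, op, _ in nodes
        if name not in consumed and op not in ("Placeholder", "Const", "NoOp")
    ]
    return input_names, output_names


# A tiny hand-written graph, so the heuristic can be checked without a .pb file.
toy_nodes = [
    ("input", "Placeholder", []),
    ("conv/weights", "Const", []),
    ("conv/Conv2D", "Conv2D", ["input", "conv/weights"]),
    ("output_node0", "Identity", ["conv/Conv2D:0"]),
]
print(guess_io_names(toy_nodes))  # prints (['input'], ['output_node0'])
```

On a real model you would call guess_io_names(load_graph_nodes("your_model.pb")) (the file name here is a placeholder) and pass the results to input_arrays and output_arrays. If the heuristic returns more than one output candidate, printing every node name and checking by eye is the safer route.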
