使用mmdnn將MXNET轉成Tensorflow模型

dupei發表於2019-12-10

參考:https://www.twblogs.net/a/5ca4cadbbd9eee5b1a0713af

寫這篇筆記的初衷是因為需要將Insightface的MXNET模型轉成Tensorflow,而mmdnn上提供的mmconvert命令卻沒有成功(估計是哪裡用錯了),故而找其他的文章來實現相同的功能。

準備mmdnn環境

因為需要將mxnet轉成tensorflow,所以,需要在python的虛擬環境中安裝這兩個框架,另外,建議使用最新的mmdnn,修復了一下bug。

pip install -U git+https://github.com/Microsoft/MMdnn.git@master
pip install mmdnn
pip install tensorflow==1.13.1

將mxnet模型轉成IR模型

下面的引數中–inputShape很重要,insightface的輸入影像的shape就是[3,112,112]

python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-0000.params -d r100 --inputShape 3,112,112

下面就是生成的IR模型
在這裡插入圖片描述

由IR模型生成code

python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath r100.pb --IRWeightPath r100.npy --dstModelPath tf_r100.py

使用上面的命令,可以生成一個tf_r100.py,其中,KitModel函式載入npy權重引數重新生成原網路框架。

生成PB檔案

基於r100.npy和tf_r100.py文​​件,固化引數,生成PB檔案。

import argparse
import tensorflow as tf
import tf_r100 as tf_fun # modify by yourself

def Network(ir_model_file):
    model = tf_fun.KitModel(ir_model_file)
    return model

def freeze_graph(ir_model_file, pb_file):
    output_node_names = "output"
    _,fc1 = Network(ir_model_file)
    fc1 = tf.identity(fc1, name="output")

    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,
            output_node_names=output_node_names.split(","))

        with tf.gfile.GFile(pb_file, "wb") as f:
            f.write(output_graph_def.SerializeToString())

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='freeze model')
    parser.add_argument('ir_model_file', help='path to ir model, which is *.npy')
    parser.add_argument('pb_file', help='path to pb file')
    args = parser.parse_args()
    freeze_graph(args.ir_model_file, args.pb_file)
    print("finish!")

顯示模型的網路

import argparse
import tensorflow as tf

def display_network(pb_file):
    # after load model from file
    with tf.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, input_map=None, return_elements=None, name="")

    # print out operations information
    ops = tf.get_default_graph().get_operations()
    for op in ops:
        print(op.name) # print operation name
        print('> ', op.values()) # print a list of tensors it produces

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='display network')
    parser.add_argument('pb_file', help='path to pb file')
    args = parser.parse_args()
    display_network(args.pb_file)

相關文章