從Tensorflow模型檔案中解析並顯示網路結構圖(pb模型篇)

weixin_33727510發表於2018-11-07

Tensorflow官方提供的Tensorboard可以視覺化神經網路結構圖,但是說實話,我幾乎從來不用。主要是因為Tensorboard中檢視到的圖結構太混亂了,包含了網路中所有的計算節點(讀取資料節點、網路節點、loss計算節點等等)。更可怕的是,如果一個計算節點是由多個基礎計算(如加減乘除等)構成,那麼在Tensorboard中會將基礎計算節點顯示而不是作為一個整體顯示(典型的如Squeeze計算節點)。最近為了排查網路結構BUG花費一週時間,因此,狠下心來決定自己寫一個工具,將Tensorflow中的圖以最簡單的方式顯示最關鍵的網路結構。

1 Tensor物件與Operation物件

Tensorflow中,Tensor物件主要用於儲存資料如常量和變數(訓練引數),Operation物件是計算節點,如卷積計算、反摺積計算、ReLU等等。每一個Operation物件均有輸入和輸出Tensor,同理,每個Tensor物件均有對應生成該Tensor的Operation物件和使用該Tensor物件作為輸入的Operation物件。Tensor和Operation物件內均有相關屬性和函式來獲取其關聯的Operation和Tensor物件,相關屬性如下所示。

Tensor物件的op屬性指向生成該Tensor的Operation物件。
Tensor物件的consumers()函式獲取使用該Tensor物件作為輸入的Operation物件。
Operation物件的inputs屬性指向該計算節點的輸入Tensor物件。
Operation物件的outputs屬性執行該計算節點的輸出Tensor物件。

如下圖所示的網路結構中,呼叫Tensor_2物件的consumers()函式,返回的是[op_1,op_2]Tensor_3的op屬性指向的是op_1op_1的inputs屬性指向的是[Tensor_1,Tensor_2]op_1的output屬性指向的是[Tensor_3]

2154124-5acd786fb63d6a65.jpg
Tensor與Operation

有了Tensor與Operation對應在圖中的關聯關係,就可以將網路結構給畫出來。

2 提取pb檔案中的網路結構圖

pb檔案是將模型引數固化到圖檔案中,併合並了一些基礎計算和刪除了反向傳播相關計算得到的protobuf協議檔案。如果讀者還不懂如何將CKPT模型檔案轉pb檔案,請參考我另一篇文章《 Tensorflow MobileNet移植到Android》的第1節部分。有了pb模型檔案後,接下來是載入模型,載入pb模型示例程式碼如下所示。

def read_graph_from_pb(tf_model_path ,input_names,output_name):  
    with open(tf_model_path, 'rb') as f:
        serialized = f.read() 
    tf.reset_default_graph()
    gdef = tf.GraphDef()
    gdef.ParseFromString(serialized) 
    with tf.Graph().as_default() as g:
        tf.import_graph_def(gdef, name='') 
    
    with tf.Session(graph=g) as sess: 
        OPS=get_ops_from_pb(g,input_names,output_name)
    return OPS

其中,倒數第2行呼叫到的函式get_ops_from_pb()用於獲取網路結構圖中指定輸入節點和指定輸出節點之間的計算節點。之所以要指定輸入和輸出,是為了將輸入之前的計算節點(如載入資料佇列等相關計算節點)和輸出之後的計算節點(如計算loss等相關計算節點)去除,免得礙眼。函式get_ops_from_pb()實現程式碼如下。

def get_ops_from_pb(graph,input_names,output_name,save_ori_network=True):
    if save_ori_network:
        with open('ori_network.txt','w+') as w: 
            OPS=graph.get_operations()
            for op in OPS:
                txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
                w.write(txt+'\n') 
    inputs_tf = [graph.get_tensor_by_name(input_name) for input_name in input_names]
    output_tf =graph.get_tensor_by_name(output_name) 
    OPS =get_ops_from_inputs_outputs(graph, inputs_tf,[output_tf] ) 
    with open('network.txt','w+') as w: 
        for op in OPS:
            txt = str([v.name for v in op.inputs])+'---->'+op.type+'--->'+str([v.name for v in op.outputs])
            w.write(txt+'\n') 
    OPS = sort_ops(OPS)
    OPS = merge_layers(OPS)
    return OPS

在裁剪網路結構(即只保留input_names和output_name之間節點)之前,先將原始的網路結構寫入到ori_network.txt中,檔案中,每一行寫入:輸入Tensor---->op---->輸出Tensor。接下來呼叫函式get_ops_from_inputs_outputs獲取指定節點之間的節點。並呼叫sort_ops函式對所有的節點排序,以保證被依賴的節點總是出現在相關節點之前。最後呼叫merge_layers函式,將一些可以合併的計算合併成一個獨立的節點,例如,Squeeze計算相關節點合併成一個單獨的Squeeze節點,又如const-->identity兩個計算節點可以直接忽略(即刪除)。

注意:篇幅有限,這裡不再將函式get_ops_from_inputs_outputssort_opsmerge_layers貼出,相關程式碼請前往文尾提供的原始碼地址中閱讀。

3 繪製網路結構

考慮到SVG繪製圖形的簡單易用優點,將排好序的網路計算節點和相關Tensor物件資料以Javascript字串的形式寫入到HTML中,使用<line>標籤繪製箭頭,使用<rect>標籤繪製矩形,使用<ellipse>標籤繪製橢圓,使用<text>標籤顯示文字。繪製類似於如下所示影像

2154124-1439b9d90a5711b8.jpg
繪製網路結構示例

注意:篇幅有限,這裡不再介紹Javascript程式碼解析模型結構和SVG顯示相關的原理,相關程式碼請前往文尾提供的原始碼地址中閱讀。

4 測試模型顯示

《MobileNet V1官方預訓練模型的使用》文中介紹的MobileNet V1網路結構為例,下載MobileNet_v1_1.0_192檔案並壓縮後,得到mobilenet_v1_1.0_192_frozen.pb檔案。我們還需要知道mobilenet_v1_1.0_192_frozen.pb模型對應的輸入和輸出Tensor物件的名稱,好在MobileNet_v1_1.0_192壓縮包中包含檔案mobilenet_v1_1.0_192_info.txt。通過該檔案可知,輸入Tensor的名稱為:input:0,輸出Tensor名稱為:MobilenetV1/Predictions/Reshape_1:0。有了這些資訊後,呼叫函式read_graph_from_pb得到靜態圖的節點列表物件ops,呼叫函式gen_graph(ops,"save/path/graph.html")後,在目錄save/path中得到graph.html檔案,開啟graph.html後,顯示結果如下。

顯示網路結構分兩種模式:合併模式和展開模式,分別如下圖所示。

2154124-cb6173eaa8b27cf1.gif
合併模式網路結構
2154124-dd325392efe04f31.gif
擷取的展開模式網路結構

5 原始碼地址

https://github.com/huachao1001/CNNGraph

相關文章