TensorFlow 載入多個模型的方法

spearhead_cai發表於2018-11-18

採用 TensorFlow 的時候,有時候我們需要載入的不止是一個模型,那麼如何載入多個模型呢?

原文:bretahajek.com/2017/04/imp…


關於 TensorFlow 可以有很多東西可以說。但這次我只介紹如何匯入訓練好的模型(圖),因為我做不到匯入第二個模型並將它和第一個模型一起使用。並且,這種匯入非常慢,我也不想重複做第二次。另一方面,將一切東西都放到一個模型也不實際。

在這個教程中,我會介紹如何儲存和載入模型,更進一步,如何載入多個模型。

載入 TensorFlow 模型

在介紹載入多個模型之前,我們先介紹下如何載入單個模型,官方文件:www.tensorflow.org/programmers…

首先,我們需要建立一個模型,訓練並儲存它。這部分我不想過多介紹細節,只需要關注如何儲存模型以及不要忘記給每個操作命名。

建立一個模型,訓練並儲存的程式碼如下:

import tensorflow as tf
### Linear Regression 線性迴歸###
# Input placeholders
x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')
# Model parameters 定義模型的權值引數
W1 = tf.Variable([0.1], tf.float32)
W2 = tf.Variable([0.1], tf.float32)
W3 = tf.Variable([0.1], tf.float32)
b = tf.Variable([0.1], tf.float32)

# Output 模型的輸出
linear_model = tf.identity(W1 * x + W2 * x**2 + W3 * x**3 + b,
                           name='activation_opt')

# Loss 定義損失函式
loss = tf.reduce_sum(tf.square(linear_model - y), name='loss')
# Optimizer and training step 定義優化器運算
optimizer = tf.train.AdamOptimizer(0.001)
train = optimizer.minimize(loss, name='train_step')

# Remember output operation for later aplication
# Adding it to a collections for easy acces
# This is not required if you NAME your output operation
# 記得將輸出操作新增到一個集合中,但如何你命名了輸出操作,這一步可以省略
tf.add_to_collection("activation", linear_model)

## Start the session ##
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#  CREATE SAVER
saver = tf.train.Saver()

# Training loop 訓練
for i in range(10000):
    sess.run(train, {x: data, y: expected})
    if i % 1000 == 0:
        # You can also save checkpoints using global_step variable
        saver.save(sess, "models/model_name", global_step=i)

# SAVE TensorFlow graph into path models/model_name
# 儲存模型到指定路徑並命名模型檔名字
saver.save(sess, "models/model_name")
複製程式碼

注意,這裡是第一個重點--對變數和運算命名。這是為了在載入模型後可以使用指定的一些權值引數,如果不命名的話,這些變數會自動命名為類似“Placeholder_1”的名字。在複雜點的模型中,使用領域(scopes)是一個很好的做法,但這裡不做展開。

總之,==重點就是為了在載入模型的時候能夠呼叫權值引數或者某些運算操作,你必須給他們命名或者是放到一個集合中。==

當儲存模型後,在指定儲存模型的資料夾中就應該包含這些檔案:model_name.indexmodel_name.meta以及其他檔案。如果是採用checkpoints字尾命名模型名字,還會有名字包含model_name-1000的檔案,其中的數字是對應變數global_step,也就是當前訓練迭代次數。

現在我們就可以開始載入模型了。載入模型其實很簡單,我們需要的只是兩個函式即可:tf.train.import_meta_graphsaver.restore()。此外,就是提供正確的模型儲存路徑位置。另外,如果我們希望在不同機器使用模型,那麼還需要設定引數:clear_device=True

接著,我們就可以通過之前命名的名字或者是儲存到的集合名字來呼叫儲存的運算或者是權值引數了。如果使用了領域,那麼還需要包含領域的名字才行。而在實際呼叫這些運算的時候,還必須採用類似{'PlaceholderName:0': data}的輸入佔位符,否則會出現錯誤。

載入模型的程式碼如下:

sess = tf.Session()

# Import graph from the path and recover session
# 載入模型並恢復到會話中
saver = tf.train.import_meta_graph('models/model_name.meta', clear_devices=True)
saver.restore(sess, 'models/model_name')

# There are TWO options how to access the operation (choose one)
# 兩種方法來呼叫指定的運算操作,選擇其中一個都可以
  # FROM SAVED COLLECTION: 從儲存的集合中呼叫
activation = tf.get_collection('activation')[0]
  # BY NAME: 採用命名的方式
activation = tf.get_default_graph.get_operation_by_name('activation_opt').outputs[0]

# Use imported graph for data
# You have to feed data as {'x:0': data}
# Don't forget on ':0' part!
# 採用載入的模型進行操作,不要忘記輸入佔位符
data = 50
result = sess.run(activation, {'x:0': data})
print(result)
複製程式碼

多個模型

上述介紹瞭如何載入單個模型的操作,但如何載入多個模型呢?

如果使用載入單個模型的方式去載入多個模型,那麼就會出現變數衝突的錯誤,也無法工作。這個問題的原因是因為一個預設圖的緣故。衝突的發生是因為我們將所有變數都載入到當前會話採用的預設圖中。當我們採用會話的時候,我們可以通過tf.Session(graph=MyGraph)來指定採用不同的已經建立好的圖。因此,如果我們希望載入多個模型,那麼我們需要做的就是把他們載入在不同的圖,然後在不同會話中使用它們。

這裡,自定義一個類來完成載入指定路徑的模型到一個區域性圖的操作。這個類還提供run函式來對輸入資料使用載入的模型進行操作。這個類對於我是有用的,因為我總是將模型輸出放到一個集合或者對它命名為activation_opt,並且將輸入佔位符命名為x。你可以根據自己實際應用需求對這個類進行修改和擴充。

程式碼如下:

import tensorflow as tf

class ImportGraph():
    """  Importing and running isolated TF graph """
    def __init__(self, loc):
        # Create local graph and use it in the session
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.graph.as_default():
            # Import saved model from location 'loc' into local graph
            # 從指定路徑載入模型到區域性圖中
            saver = tf.train.import_meta_graph(loc + '.meta',
                                               clear_devices=True)
            saver.restore(self.sess, loc)
            # There are TWO options how to get activation operation:
            # 兩種方式來呼叫運算或者引數
              # FROM SAVED COLLECTION:            
            self.activation = tf.get_collection('activation')[0]
              # BY NAME:
            self.activation = self.graph.get_operation_by_name('activation_opt').outputs[0]

    def run(self, data):
        """ Running the activation operation previously imported """
        # The 'x' corresponds to name of input placeholder
        return self.sess.run(self.activation, feed_dict={"x:0": data})
      
      
### Using the class ###
# 測試樣例
data = 50         # random data
model = ImportGraph('models/model_name')
result = model.run(data)
print(result)
複製程式碼

總結

如果你理解了 TensorFlow 的機制的話,載入多個模型並不是一件困難的事情。上述的解決方法可能不是完美的,但是它簡單且快速。最後給出總結整個過程的樣例程式碼,這是在 Jupyter notebook 上的,程式碼地址如下:

gist.github.com/Breta01/f20…


最後,給出文章中幾個程式碼例子的 github 地址:

  1. Code for creating, training and saving TensorFlow model.
  2. Importing and using TensorFlow graph (model)
  3. Class for importing multiple TensorFlow graphs.
  4. Example of importing multiple TensorFlow modules

歡迎關注我的微信公眾號--機器學習與計算機視覺或者掃描下方的二維碼,在後臺留言,和我分享你的建議和看法,指正文章中可能存在的錯誤,大家一起交流,學習和進步!

TensorFlow 載入多個模型的方法

推薦閱讀

1.機器學習入門系列(1)--機器學習概覽(上)

2.機器學習入門系列(2)--機器學習概覽(下)

3.[GAN學習系列] 初識GAN

4.[GAN學習系列2] GAN的起源

5.谷歌開源的 GAN 庫--TFGAN

相關文章