雖然說 TensorFlow 2.0 即將問世,但是有一些模組的內容卻是不大變化的。其中就有 tf.saved_model 模組,主要用於模型的儲存和恢復。為了防止學習記錄檔案丟失或者蠢笨的腦子直接遺忘掉這部分內容,在此做點簡單的記錄,以便將來查閱。
最近為了一個課程作業,不得已涉及到關於影像超解析度恢復的內容,不得不準備隨時儲存訓練的模型,只好再回過頭來瞄一眼 TensorFlow 文件,真是太痛苦了。
tf.saved_model 模組下面有很多檔案和函式,精力有限,只好選擇於自己有用的東西來看,可能並不全面,望日後補上。
其中最重要的就是該模組下的一個類:tf.saved_model.builder.SavedModelBuilder
tf.saved_model.builder.SavedModelBuilder: # 建構函式 .__init__(export_dir) """ 作用: 建立一個儲存模型的例項物件 引數: export_dir: 模型匯出路徑,由於 TensorFlow 會在你指定的路徑上建立資料夾和檔案,所以指定的路徑最後不需要帶 /, 例如:export_dir=`/home/***/saved_model` 即可,最後不需要加上 / """ # 方法 # 1 .add_meta_graph_and_variables(sess, tags, signature_def_map=None, assets_collection=None, clear_devices=False, main_op=None, strip_default_attrs=False, saver=None) """ 作用: 儲存會話物件中的 graph 和所有變數,具體描述可參見文件 引數: sess: TensorFlow 會話物件,用於儲存元圖和變數 tags: 用於儲存元圖的標記集(如果存在多個圖物件,需要設定保證每個圖示籤不一樣),是一個列表 signature_def_map: 一個字典,儲存模型時傳入的引數,key 可以是字串,也可以是 tf.saved_model.signature_constants 檔案下預定義的變數, 值為 signatureDef protobuf(protobuf 是一種結構化的資料儲存格式) assets_collection: 略 clear_devices: 如果需要清除預設圖上的裝置資訊,則設定為 true main_op: 這個引數包括後面一系列與其相關的東西沒有弄明白 strip_default_attrs: 如果設定為 True,將從 NodeDefs 中刪除預設值屬性 saver: tf.train.Saver 的一個例項,用於匯出元圖並儲存變數 """ # 2 .add_meta_graph() """ 作用: 其除了沒有 sess 引數以外,其他引數和 .add_meta_graph_and_variables() 一模一樣 呼叫此方法之前必須先呼叫 .add_meta_graph_and_variables() 方法 """ # 3 .save(as_text=False) """ 作用: 將內建的 savedModel protobuf 寫入磁碟 """
除了這個最重要的類以外,tf.saved_model 模組還提供了一些方便構建 builder 和載入模型的函式方法。
# 1 tf.saved_model.utils.build_tensor_info(tensor) """ 作用: 構建 TensorInfo protobuf,根據輸入的 tensor 構建相應的 protobuf,返回的 TensorInfo 中包含輸入 tensor 的 name,shape,dtype 資訊 引數: tensor: Tensor 或 SparseTensor """ # 2 tf.saved_model.signature_def_utils.build_signature_def(inputs=None, outputs=None, method_name=None) """ 作用: 構建 SignatureDef protobuf,並返回 SignatureDef protobuf 引數: inputs: 一個字典,鍵為字串型別,值為關於 tensor 的資訊,也就是上述的 .build_tensor_info() 函式返回的 TensorInfo protobuf outputs: 一個字典,同上 method_name: SignatureDef 名稱 """ # 3 tf.saved_model.utils.get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None) """ 作用: 根據一個 TensorInfo protobuf 解析出一個 tensor 引數: tensor_info: 一個 TensorInfo protobuf graph: tensor 所存在的 graph,引數為 None 時,使用預設圖 import_scope: 給 tensor 的 name 加上字首 """ # 4 tf.saved_model.loader.load(sess, tags, export_dir, import_scope=None, **saver_kwargs) """ 作用: 載入已儲存的模型 引數: sess: 用於恢復模型的 tf.Session() 物件 tags: 用於標識 MetaGraphDef 的標記,應該和儲存模型時使用的此引數完全一致 export_dir: 模型儲存路徑 import_scope: 加字首 """
除了這些以外,還有一些 TensorFlow 為了方便而預定義的一些變數,這些變數完全可以使用自定義字串代替,不再贅述。詳情:https://tensorflow.google.cn/api_docs/python/tf/saved_model
如果只看這些內容的話,確實會使人產生巨大的疑惑,下面是具體實踐的例子:
import tensorflow as tf from tensorflow import saved_model as sm # 首先定義一個極其簡單的計算圖 X = tf.placeholder(tf.float32, shape=(3, )) scale = tf.Variable([10, 11, 12], dtype=tf.float32) y = tf.multiply(X, scale) # 在會話中執行 with tf.Session() as sess: sess.run(tf.initializers.global_variables()) value = sess.run(y, feed_dict={X: [1., 2., 3.]}) print(value) # 準備儲存模型 path = `/home/×××/tf_model/model_1` builder = sm.builder.SavedModelBuilder(path) # 構建需要在新會話中恢復的變數的 TensorInfo protobuf X_TensorInfo = sm.utils.build_tensor_info(X) scale_TensorInfo = sm.utils.build_tensor_info(scale) y_TensorInfo = sm.utils.build_tensor_info(y) # 構建 SignatureDef protobuf SignatureDef = sm.signature_def_utils.build_signature_def( inputs={`input_1`: X_TensorInfo, `input_2`: scale_TensorInfo}, outputs={`output`: y_TensorInfo}, method_name=`what` ) # 將 graph 和變數等資訊寫入 MetaGraphDef protobuf # 這裡的 tags 裡面的引數和 signature_def_map 字典裡面的鍵都可以是自定義字串,TensorFlow 為了方便使用,不在新地方將自定義的字串忘記,可以使用預定義的這些值 builder.add_meta_graph_and_variables(sess, tags=[sm.tag_constants.TRAINING], signature_def_map={sm.signature_constants.CLASSIFY_INPUTS: SignatureDef} ) # 將 MetaGraphDef 寫入磁碟 builder.save()
這樣我們就把模型整體儲存到了磁碟中,而且我們將三個變數 X, scale, y 全部序列化後儲存到了其中,所以恢復模型時便可以將他們完全解析出來:
import tensorflow as tf from tensorflow import saved_model as sm # 需要建立一個會話物件,將模型恢復到其中 with tf.Session() as sess: path = `/home/×××/tf_model/model_1` MetaGraphDef = sm.loader.load(sess, tags=[sm.tag_constants.TRAINING], export_dir=path) # 解析得到 SignatureDef protobuf SignatureDef_d = MetaGraphDef.signature_def SignatureDef = SignatureDef_d[sm.signature_constants.CLASSIFY_INPUTS] # 解析得到 3 個變數對應的 TensorInfo protobuf X_TensorInfo = SignatureDef.inputs[`input_1`] scale_TensorInfo = SignatureDef.inputs[`input_2`] y_TensorInfo = SignatureDef.outputs[`output`] # 解析得到具體 Tensor # .get_tensor_from_tensor_info() 函式中可以不傳入 graph 引數,TensorFlow 自動使用預設圖 X = sm.utils.get_tensor_from_tensor_info(X_TensorInfo, sess.graph) scale = sm.utils.get_tensor_from_tensor_info(scale_TensorInfo, sess.graph) y = sm.utils.get_tensor_from_tensor_info(y_TensorInfo, sess.graph) print(sess.run(scale)) print(sess.run(y, feed_dict={X: [3., 2., 1.]})) # 輸出 [10. 11. 12.] [30. 22. 12.]
可以看出模型整體和變數個體都被完整地儲存了下來。其中涉及的關於 protobuf 的知識,需要補習,在 TensorFlow 中好多地方都用到了相關的知識。上述恢復模型的程式碼中對具體的 TensorInfo protobuf 解析時,還可以使用另一種方式得到相應的 Tensor:
# 已知 X_TensorInfo, scale_TensorInfo, y_TensorInfo X = sess.graph.get_tensor_by_name(X_TensorInfo.name) scale = sess.grpah.get_tensor_by_name(scale_TensorInfo.name) y = sess.graph.get_tensor_by_name(y_TensorInfo.name) # 因為 TensorFlow 構建 TensorInfo protobuf 時,使用了 Tensor 的 name 資訊,所以可以直接讀出來使用