tf.keras: 儲存與載入模型

BOQTATQ發表於2020-12-14

tf.keras.save 和 tf.keras.models.load_model 問題

使用 tf.keras.models.load_model 載入模型後會出現準確率極低的問題, 就好像沒有進行訓練過一樣. 這是因為我們在模型中使用了 v1.x 優化器 (來自 tf.compat.v1.train), 而此類優化器由於與檢查點不相容, 所以在載入模型時會丟失優化器的狀態值. 我們只能通過重新編譯模型來恢復優化器的狀態.

本文如何講述

本文用到的庫
構建模型
訓練模型
儲存模型
載入模型
模型評估

本文用到的庫

# 這是本文用到的庫
import tensorflow as tf
import pathlib

當下有一名為 model 的已訓練完的採用了 v1.x 優化器 ‘adam’ 的模型.

構建模型

model = tf.keras.Sequential([
    layers.experimental.preprocessing.Rescaling(1 / 255),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10)
])
model.compile(
	optimizer='adam',  # 這裡的優化器 'adam' 就是一個 v1.x 優化器
	loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
	metrics=['accuracy']
)

訓練模型

# 假設我們已經有了一個名為 train_set 的訓練集, 使用該訓練集對模型進行訓練(5個輪次).
model.fit(train_set, epoch=5)

儲存模型

我們想把 model 儲存到當前工作目錄下, 檔名為 model.h5 (HDF5 格式), 步驟如下:

首先, 設定模型儲存的路徑:

# 使用 pathlib.Path 類構造路徑, 可以免受平臺差異性的困擾
model_path = pathlib.Path(r'assets/model.h5') 

然後, 將模型儲存成檔案.

model = model.save(model_path)

載入模型

# 匯入模型
new_model = tf.keras.models.load_model(model_path)

# 需再次編譯該模型, 編譯條件須與訓練時相同
model.compile(
	optimizer='adam',
	loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
	metrics=['accuracy']	   
)

評估模型

假設我們已經有了一個測試集 test_set, 我們可以使用此測試集對模型進行評估:

model.evaluate(test_set)

總結

由於 tf.compat.v1.train 裡的優化器不相容檢查點, 所以我們後續載入模型時會丟失優化器的狀態值, 導致後續評估結果過差. 我們可通過重新編譯的方式解決此問題.

相關文章