tf.keras: 儲存與載入模型
tf.keras.save 和 tf.keras.models.load_model 問題
使用 tf.keras.models.load_model 載入模型後會出現準確率極低的問題, 就好像沒有進行訓練過一樣. 這是因為我們在模型中使用了 v1.x 優化器 (來自 tf.compat.v1.train), 而此類優化器由於與檢查點不相容, 所以在載入模型時會丟失優化器的狀態值. 我們只能通過重新編譯模型來恢復優化器的狀態.
本文如何講述
本文用到的庫
# 這是本文用到的庫
import tensorflow as tf
import pathlib
設當下有一名為 model
的已訓練完的採用了 v1.x 優化器 ‘adam’ 的模型.
構建模型
model = tf.keras.Sequential([
layers.experimental.preprocessing.Rescaling(1 / 255),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(10)
])
model.compile(
optimizer='adam', # 這裡的優化器 'adam' 就是一個 v1.x 優化器
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics=['accuracy']
)
訓練模型
# 假設我們已經有了一個名為 train_set 的訓練集, 使用該訓練集對模型進行訓練(5個輪次).
model.fit(train_set, epoch=5)
儲存模型
我們想把 model 儲存到當前工作目錄下, 檔名為 model.h5 (HDF5 格式), 步驟如下:
首先, 設定模型儲存的路徑:
# 使用 pathlib.Path 類構造路徑, 可以免受平臺差異性的困擾
model_path = pathlib.Path(r'assets/model.h5')
然後, 將模型儲存成檔案.
model = model.save(model_path)
載入模型
# 匯入模型
new_model = tf.keras.models.load_model(model_path)
# 需再次編譯該模型, 編譯條件須與訓練時相同
model.compile(
optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
評估模型
假設我們已經有了一個測試集 test_set
, 我們可以使用此測試集對模型進行評估:
model.evaluate(test_set)
總結
由於 tf.compat.v1.train 裡的優化器不相容檢查點, 所以我們後續載入模型時會丟失優化器的狀態值, 導致後續評估結果過差. 我們可通過重新編譯的方式解決此問題.
相關文章
- Tensorflow SavedModel模型的儲存與載入模型
- 訓練模型的儲存與載入模型
- Tensorflow模型的儲存與恢復載入模型
- 【tf.keras】tf.keras載入AlexNet預訓練模型Keras模型
- spacy儲存和載入模型模型
- Pytorch模型檔案`*.pt`與`*.pth` 的儲存與載入PyTorch模型
- Transformers 儲存並載入模型 | 八ORM模型
- 儲存載入模型model.save()模型
- [PyTorch 學習筆記] 7.1 模型儲存與載入PyTorch筆記模型
- pytorch-模型儲存與載入自己訓練的模型詳解PyTorch模型
- tensorflow模型持久化儲存和載入模型持久化
- 機器學習之儲存與載入.pickle模型檔案機器學習模型
- 【小白學PyTorch】19 TF2模型的儲存與載入PyTorchTF2模型
- Pytorch | Tutorial-07 儲存和載入模型PyTorch模型
- 2.影像的載入與儲存
- 在 Python 中儲存和載入機器學習模型Python機器學習模型
- 全面解析Pytorch框架下模型儲存,載入以及凍結PyTorch框架模型
- PyTorch儲存模型斷點以及載入斷點繼續訓練PyTorch模型斷點
- DAOS 分散式非同步物件儲存|儲存模型分散式非同步物件模型
- Java 載入、操作和儲存WPS文字文件Java
- 大模型儲存實踐:效能、成本與多雲大模型
- RocketMQ(十):資料儲存模型設計與實現MQ模型
- 儲存過程與儲存函式儲存過程儲存函式
- TensorFlow模型儲存和提取方法模型
- 掌握Hive資料儲存模型Hive模型
- 資料中心儲存 TCO 模型模型
- 使用Spark載入資料到SQL Server列儲存表SparkSQLServer
- C++之OpenCV入門到提高002:載入、修改、儲存影像C++OpenCV
- 載入模型模型
- Gartner:浪潮儲存進入分散式儲存前三分散式
- Oracle 共享儲存掛載Oracle
- Spark SQL使用簡介(3)--載入和儲存資料SparkSQL
- opencv學習筆記(二)-- 載入、修改和儲存影像OpenCV筆記
- 雲原生儲存詳解:容器儲存與 K8s 儲存卷K8S
- 學習筆記14:模型儲存筆記模型
- 機器學習-訓練模型的儲存與恢復(sklearn)機器學習模型
- MySQL入門--儲存引擎MySql儲存引擎
- 【C語言進階】通訊錄的儲存和載入C語言