tensorflow模型持久化儲存和載入
模型檔案的儲存
tensorflow將模型保持到本地會生成4個檔案:
meta檔案:儲存了網路的圖結構,包含變數、op、集合等資訊
ckpt檔案: 二進位制檔案,儲存了網路中所有權重、偏置等變數數值,分為兩個檔案,一個是.data-00000-of-00001 檔案,一個是 .index 檔案
checkpoint檔案:文字檔案,記錄了最新保持的5個模型檔案列表tf中模型儲存使用 tf.train.Saver類來儲存模型。使用方式:
1. 在Session外生成一個模型儲存物件
saver = tf.train.Saver()
2. 在Session中以當前環境Session為引數,儲存模型到本地磁碟
saver.save(sess,"./model/Model_test")
Saver類的建構函式定義:
def __init__(self,
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None):
常用的幾個變數:- var_list: 指定要儲存的變數的序列或字典,預設為None,儲存所有變數
- reshape: 可選引數,如果為True,表示允許變數以不同的形狀儲存,如果為False,表示保持的變數只能有同樣一種形狀和資料型別,預設為False;
- max_to_keep: 定義最多儲存最近的多少個模型檔案,預設是5個;
- keep_checkpoint_every_n_hours: 定義多少個小時儲存模型一次,預設10000個小時;
- name: 可選引數,新增到操作名稱前的字首,預設None;
- restore_sequentially:定義在裝置上是否按照順序恢復變數,順序恢復可以降低內參使用,預設False;
- saver_def:可選引數,用在需要重建Saver物件場合,預設None;
- allow_empty:是否允許儲存一個沒有任何變數的空圖,預設False;
saver.save函式定義:
def save(self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True,
strip_default_attrs=False):
常用引數:- sess: 當前的會話環境;
- save_path: 模型儲存路徑;
- global_step: 訓練輪次,如果新增,會在模型檔名稱後加上這個輪次的字尾,預設None,不新增,最好設定這個引數,要不然模型檔案就會由於重名覆蓋掉之前儲存的;
- latest_filename: checkpoint文字檔案的名稱,預設為‘checkpoint’
- meta_graph_suffix: 儲存的網路圖結構檔案的字尾,預設為mata;
- write_meta_graph: 定義是否儲存網路結構,預設是True儲存,由於網路結構在訓練過程中是不會變的,所以儲存過一次後可以設定 write_meta_graph為False,不用每次都儲存圖結構;
簡單示例,以下程式中 X和Y是一個含有128個元素的列表,每個元素是一個二維陣列,定義公式 Y = (X*w1+b1)*(w2)+b2 ,使用tensorflow網路迭代求 w和b 的最優解,完成之後保持模型到本地 model_test 資料夾。
# -*- coding: utf-8 -*-)
import tensorflow as tf
from numpy.random import RandomState
# 定義訓練資料batch的大小
batch_size = 8
# 在shape上使用None表示該維度的具體數值不定
x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')
# 定義神經網路的引數
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
bias1 = tf.Variable(tf.random_normal([3], stddev=1, seed=1))
bias2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
# 定義神經網路前向傳播的過程,即操作
a = tf.nn.relu(tf.matmul(x, w1) + bias1)
y = tf.nn.relu(tf.matmul(a, w2) + bias2)
# 定義損失函式和反向傳播演算法
loss = tf.reduce_sum(tf.pow((y - y_), 2))
train_step = tf.train.AdamOptimizer(0.001).minimize(loss) # 梯度下降優化演算法
# produce the data,通過隨機數生成一個模擬資料集
rdm = RandomState(seed=1) # 設定seed = 1 ,使每次生成的隨機數一樣
dataset_size = 128
X = rdm.rand(dataset_size, 2)
Y = [[x1 + 10 * x2] for (x1, x2) in X]
# 生成一個保持模型物件
saver = tf.train.Saver()
# creare a session,建立一個會話來執行TensorFlow程式
with tf.Session() as sess:
# 初始化變數
sess.run(tf.global_variables_initializer())
# 設定訓練的輪數
STEPS = 10000
for i in range(STEPS + 1):
# get batch_size samples data to train,每次選取batch_size個樣本進行訓練
start = (i * batch_size) % dataset_size
end = min(start + batch_size, dataset_size)
# 通過選取的樣本訓練神經網路並更新引數
sess.run(train_step, feed_dict={x: X[start: end], y_: Y[start: end]})
if i % 500 == 0:
# 每隔一段時間計算在所有資料上的loss並輸出
total_cross_entropy= sess.run([loss], feed_dict={x: X, y_: Y})
print ("steps: {}, total loss: {}".format(i,total_cross_entropy))
# 在訓練結束之後,保持神經網路模型
saver.save(sess, "./model_saved/model_test")
print sess.run((w1,bias1))
print('^^^^^^^^^^^^^^^^^^^^^^^^^')
print sess.run((w2,bias2))
# output:
# steps: 0, total loss: [2599.938]
# steps: 500, total loss: [873.66064]
# steps: 1000, total loss: [667.79114]
# steps: 1500, total loss: [483.07538]
# steps: 2000, total loss: [300.2436]
# steps: 2500, total loss: [159.57596]
# steps: 3000, total loss: [74.0152]
# steps: 3500, total loss: [30.022282]
# steps: 4000, total loss: [10.848581]
# steps: 4500, total loss: [3.8684735]
# steps: 5000, total loss: [1.6775348]
# steps: 5500, total loss: [0.87090385]
# steps: 6000, total loss: [0.47393078]
# steps: 6500, total loss: [0.2628175]
# steps: 7000, total loss: [0.13229856]
# steps: 7500, total loss: [0.058554076]
# steps: 8000, total loss: [0.022747971]
# steps: 8500, total loss: [0.007896027]
# steps: 9000, total loss: [0.002599821]
# steps: 9500, total loss: [0.0007222026]
# steps: 10000, total loss: [0.00021833208]
# (array([[-0.8113182 , 0.741788 , -0.06654923],
# [-2.4427042 , 1.7258024 , 3.505848 ]], dtype=float32), array([-0.8113182 , 0.9206883 , -0.00473781], dtype=float32))
# ^^^^^^^^^^^^^^^^^^^^^^^^^
# (array([[-0.8113182],
# [ 1.5360606],
# [ 2.0962803]], dtype=float32), array([-1.4044524], dtype=float32))
經過10000次迭代之後完成訓練,在本地model_test目錄下建立了模型的4個檔案:
模型檔案的載入
模型檔案的圖結構跟資料是分開儲存的,載入模型時候可以先載入圖結構,再載入圖中的引數(在Session中操作):
saver=tf.train.import_meta_graph('./model_saved/model_test.meta')
saver.restore(sess, tf.train.latest_checkpoint('./model_saved'))
或者一次性載入:
saver = tf.train.Saver()
saver.restore(sess, './model_saved/model_test')
或:
saver.restore(sess, tf.train.latest_checkpoint('./model_saved'))
‘model_test’是儲存的模型檔名稱(字首名,不帶字尾)
更加安全一點的載入方式,先判斷模型檔案是否存在判斷(推薦使用這種方式):
ckpt = tf.train.get_checkpoint_state('./model_saved')
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
# -*- coding: utf-8 -*-)
import tensorflow as tf
from numpy.random import RandomState
# 定義訓練資料batch的大小
batch_size = 8
# 在shape上使用None表示該維度的具體數值不定
x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')
# 定義神經網路的引數
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
bias1 = tf.Variable(tf.random_normal([3], stddev=1, seed=1))
bias2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
# 定義神經網路前向傳播的過程,即操作
a = tf.nn.relu(tf.matmul(x, w1) + bias1)
y = tf.nn.relu(tf.matmul(a, w2) + bias2)
# produce the data,通過隨機數生成一個模擬資料集
rdm = RandomState(seed=1) # 設定seed = 1 ,使每次生成的隨機數一樣
dataset_size = 128
X = rdm.rand(dataset_size, 2)
Y = [[x1 + 10 * x2] for (x1, x2) in X]
# creare a session,建立一個會話來執行TensorFlow程式
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./model_saved/model_test.meta')
saver.restore(sess, tf.train.latest_checkpoint('./model_saved'))
# 初始化變數
sess.run(tf.global_variables_initializer())
print(sess.run(y,feed_dict={x: X[0: 10], y_: Y[0: 10]}))
# output:
# [[2.4518511]
# [1.4534602]
# [1.7382364]
# [1.8725655]
# [2.3733683]
# [2.4501202]
# [2.0117776]
# [1.582149 ]
# [2.4224167]
# [1.7438407]]
tf.train.Saver常用函式列表:
操作 | 描述 |
---|---|
類tf.train.Saver(Saving and Restoring Variables) | |
tf.train.Saver.__init__(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None) | 建立一個儲存器Saver var_list定義需要儲存和恢復的變數 |
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix=’meta’, write_meta_graph=True) | 儲存變數 |
tf.train.Saver.restore(sess, save_path) | 恢復變數 |
tf.train.Saver.last_checkpoints | 列出最近未刪除的checkpoint 檔名 |
tf.train.Saver.set_last_checkpoints(last_checkpoints) | 設定checkpoint檔名列表 |
tf.train.Saver.set_last_checkpoints_with_time(last_checkpoints_with_time) | 設定checkpoint檔名列表和時間戳 |
相關文章
- Tensorflow SavedModel模型的儲存與載入模型
- Tensorflow模型的儲存與恢復載入模型
- spacy儲存和載入模型模型
- TensorFlow模型儲存和提取方法模型
- Pytorch | Tutorial-07 儲存和載入模型PyTorch模型
- Redis持久化儲存Redis持久化
- Room-資料持久化儲存(入門)OOM持久化
- Transformers 儲存並載入模型 | 八ORM模型
- 儲存載入模型model.save()模型
- Flutter持久化儲存之檔案儲存Flutter持久化
- Flutter持久化儲存之資料庫儲存Flutter持久化資料庫
- Flutter持久化儲存之key-value儲存Flutter持久化
- 訓練模型的儲存與載入模型
- tf.keras: 儲存與載入模型Keras模型
- Python 載入 TensorFlow 模型Python模型
- scrapy框架持久化儲存框架持久化
- Redis 持久化儲存詳解Redis持久化
- Redis持久化儲存——>RDB & AOFRedis持久化
- (三)Kubernetes---持久化儲存持久化
- 在 Python 中儲存和載入機器學習模型Python機器學習模型
- 1.05 docker的持久化儲存和資料共享Docker持久化
- Docker的持久化儲存和資料共享(四)Docker持久化
- iOS資料持久化儲存-CoreDataiOS持久化
- Kubernetes 持久化資料儲存 StorageClass持久化
- Kubuesphere部署Ruoyi(三):持久化儲存配置持久化
- Kubernetes的故事之持久化儲存(十)持久化
- Docker容器中資料兩種持久化儲存方式:卷和掛載宿主目錄Docker持久化
- [PyTorch 學習筆記] 7.1 模型儲存與載入PyTorch筆記模型
- pytorch-模型儲存與載入自己訓練的模型詳解PyTorch模型
- AOF持久化(儲存的是操作redis命令)持久化Redis
- 利用Kubernetes實現容器的持久化儲存持久化
- 容器雲對接持久化儲存並使用持久化
- TensorFlow 載入多個模型的方法模型
- 使用容器化塊儲存OpenEBS在K3s中實現持久化儲存持久化
- Pytorch模型檔案`*.pt`與`*.pth` 的儲存與載入PyTorch模型
- 全面解析Pytorch框架下模型儲存,載入以及凍結PyTorch框架模型
- 機器學習之儲存與載入.pickle模型檔案機器學習模型
- 【小白學PyTorch】19 TF2模型的儲存與載入PyTorchTF2模型