tensorflow saver的問題

weixin_33866037發表於2018-05-25

參考資料

https://www.tensorflow.org/api_docs/python/tf/train/Saver

https://blog.csdn.net/u012436149/article/details/56665612

tensorflow學習筆記(三十四):Saver(儲存與載入模型)

2017年02月23日 09:41:40

閱讀數:5627

Saver

tensorflow 中的 Saver 物件是用於 引數儲存和恢復的。如何使用呢? 

這裡介紹了一些基本的用法。 

官網中給出了這麼一個例子:

v1 = tf.Variable(..., name='v1')v2 = tf.Variable(..., name='v2')# Pass the variables as a dict:saver = tf.train.Saver({'v1': v1,'v2': v2})# Or pass them as a list.saver = tf.train.Saver([v1, v2])# Passing a list is equivalent to passing a dict with the variable op names# as keys:saver = tf.train.Saver({v.op.name: vforvin[v1, v2]})#注意,如果不給Saver傳var_list 引數的話, 他將已 所有可以儲存的 variable作為其var_list的值。

這裡使用了三種不同的方式來建立 saver 物件, 但是它們內部的原理是一樣的。我們都知道,引數會儲存到 checkpoint檔案中,通過鍵值對的形式在 checkpoint中存放著。如果 Saver 的建構函式中傳的是 dict,那麼在 save 的時候,checkpoint檔案中存放的就是對應的 key-value。如下:

importtensorflowastf# Create some variables.v1 = tf.Variable(1.0, name="v1")v2 = tf.Variable(2.0, name="v2")saver = tf.train.Saver({"variable_1":v1,"variable_2": v2})# Use the saver object normally after that.withtf.Session()assess:    tf.global_variables_initializer().run()    saver.save(sess,'test-ckpt/model-2')


我們通過官方提供的工具來看一下 checkpoint 中儲存了什麼

fromtensorflow.python.tools.inspect_checkpointimportprint_tensors_in_checkpoint_fileprint_tensors_in_checkpoint_file("test-ckpt/model-2",None,True)# 輸出:#tensor_name:  variable_1#1.0#tensor_name:  variable_2#2.0

如果構建saver物件的時候,我們傳入的是 list, 那麼將會用對應 Variable 的 variable.op.name 作為 key。

importtensorflowastf# Create some variables.v1 = tf.Variable(1.0, name="v1")v2 = tf.Variable(2.0, name="v2")saver = tf.train.Saver([v1, v2])# Use the saver object normally after that.withtf.Session()assess:    tf.global_variables_initializer().run()    saver.save(sess,'test-ckpt/model-2')

我們再使用官方工具列印出 checkpoint 中的資料,得到

tensor_name:  v11.0tensor_name:  v22.0

如果我們現在想將 checkpoint 中v2的值restore到v1 中,v1的值restore到v2中,我們該怎麼做? 

這時,我們只能採用基於 dict 的 saver

importtensorflowastf# Create some variables.v1 = tf.Variable(1.0, name="v1")v2 = tf.Variable(2.0, name="v2")saver = tf.train.Saver({"variable_1":v1,"variable_2": v2})# Use the saver object normally after that.withtf.Session()assess:    tf.global_variables_initializer().run()    saver.save(sess,'test-ckpt/model-2')

save 部分的程式碼如上所示,下面寫 restore 的程式碼,和save程式碼有點不同。

```pythonimporttensorflowastf# Create some variables.v1 = tf.Variable(1.0, name="v1")v2 = tf.Variable(2.0, name="v2")#restore的時候,variable_1對應到v2,variable_2對應到v1,就可以實現目的了。saver = tf.train.Saver({"variable_1":v2,"variable_2": v1})# Use the saver object normally after that.withtf.Session()assess:    tf.global_variables_initializer().run()    saver.restore(sess,'test-ckpt/model-2')    print(sess.run(v1), sess.run(v2))# 輸出的結果是 2.0 1.0,如我們所望

我們發現,其實 建立 saver物件時使用的鍵值對就是表達了一種對應關係:

save時, 表示:variable的值應該儲存到 checkpoint檔案中的哪個 key下

restore時,表示:checkpoint檔案中key對應的值,應該restore到哪個variable

其它

一個快速找到ckpt檔案的方式

ckpt = tf.train.get_checkpoint_state(ckpt_dir)ifckptandckpt.model_checkpoint_path:    saver.restore(sess, ckpt.model_checkpoint_path)

相關文章