tf.train.NewCheckpointReader
儲存訓練模型的時候不僅持久化了計算圖結構,也持久化了變數的取值。
TensorFlow提供的tf.train.NewCheckpointReader
類來檢視儲存的變數的資訊。
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0,tf.float32, [1]),name='v1')
v2 = tf.Variable(tf.constant(2.0,tf.float32, [1]),name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(result.eval())
saver.save(sess,'./model/model.ckpt')
import tensorflow as tf
# tf.train.NewCheckpointReader可以讀取checkpoint檔案中儲存的所有變數
reader = tf.train.NewCheckpointReader('./model/model.ckpt')
# 獲取所有變數列表,是一個從變數名到變數維度的字典
global_variables = reader.get_variable_to_shape_map()
for variable_name in global_variables:
# variable_name為變數名稱,global_variables[variable_name]為變數維度
print(variable_name, global_variables[variable_name])
# 輸出: v2 [1]
# v1 [1]
print(reader.get_tensor('v1')) # 輸出: [1.]
print(reader.get_tensor('v2')) # 輸出: [2.]
# 獲取所有變數列表,是一個從變數名到變數資料型別的字典
global_variables1 = reader.get_variable_to_dtype_map()
for variable_name in global_variables1:
print(variable_name, global_variables1[variable_name])
# 輸出: v2 <dtype: 'float32'>
# v1 <dtype: 'float32'>
print(reader.get_tensor('v1')) # 輸出: [1.]
print(reader.get_tensor('v2')) # 輸出: [2.]