tf.train.NewCheckpointReader

GarryLau發表於2020-11-05

儲存訓練模型的時候不僅持久化了計算圖結構,也持久化了變數的取值。
TensorFlow提供的tf.train.NewCheckpointReader類來檢視儲存的變數的資訊。

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0,tf.float32, [1]),name='v1')
v2 = tf.Variable(tf.constant(2.0,tf.float32, [1]),name='v2')

result = v1 + v2
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(result.eval())
    saver.save(sess,'./model/model.ckpt')

在這裡插入圖片描述


import tensorflow as tf

# tf.train.NewCheckpointReader可以讀取checkpoint檔案中儲存的所有變數
reader = tf.train.NewCheckpointReader('./model/model.ckpt')

# 獲取所有變數列表,是一個從變數名到變數維度的字典
global_variables = reader.get_variable_to_shape_map()
for variable_name in global_variables:
    # variable_name為變數名稱,global_variables[variable_name]為變數維度
    print(variable_name, global_variables[variable_name])
     # 輸出: v2 [1]
     #      v1 [1]
print(reader.get_tensor('v1'))  # 輸出: [1.]
print(reader.get_tensor('v2'))  # 輸出: [2.]

# 獲取所有變數列表,是一個從變數名到變數資料型別的字典
global_variables1 = reader.get_variable_to_dtype_map()
for variable_name in global_variables1:
    print(variable_name, global_variables1[variable_name])
     # 輸出: v2 <dtype: 'float32'>
     #      v1 <dtype: 'float32'>
print(reader.get_tensor('v1'))  # 輸出: [1.]
print(reader.get_tensor('v2'))  # 輸出: [2.]