tf.shape()和tensor.get_shape()

Laox1ao發表於2018-04-11

問題

資料輸入的格式為

input = tf.placeholder([None,xxx,xxx],dtype=tf.float32)

需要得到batch的維度來進行中間Variable的初始化

val = tf.zeros([batch_size,xxx,xxx],dtype=tf.float32)

方法

可行:

val = tf.zeros([tf.shape(input)[0],xxx,xxx],dtype=tf.float32)

失敗:

val = tf.zeros([input.get_shape()[0],xxx,xxx],dtype=tf.float32)
val = tf.zeros([input.shape()[0],xxx,xxx],dtype=tf.float32)

錯誤提示:

ValueError: Cannot convert a partially known TensorShape to a Tensor

原因:
可能是由於get_shape()返回的是元組,tf.shape()返回的是tensor,所以tf.shape()返回的為None的batch_size維度可以繼續作為其他tensor的維度,而get_shape()由於返回的是元組,取到的batch_size維度直接是None了,無法作為其他tensor的維度。

相關文章