本文翻譯自:《Understanding static and dynamic shapes》, 如有侵權請聯絡刪除,僅限於學術交流,請勿商用。如有謬誤,請聯絡指出。
TensorFlow中的張量具有靜態形狀屬性,該屬性在圖形構造期間確定。靜態形狀可能未指定。例如,我們可以定義一個形狀張量[None,128]:
import tensorflow as tf
a = tf.placeholder(tf.float32, [None, 128])
複製程式碼
這意味著第一個維度可以是任何大小,並將在Session.run()
期間動態確定。您可以按如下方式查詢Tensor的靜態形狀:
static_shape = a.shape.as_list() # returns [None, 128]
複製程式碼
要獲得張量的動態形狀,可以呼叫tf.shape
方法,它返回一個給定張量代表的形狀:
dynamic_shape = tf.shape(a)
複製程式碼
可以使用Tensor.set_shape()
方法設定張量的靜態形狀:
a.set_shape([32, 128]) # static shape of a is [32, 128]
a.set_shape([None, 128]) # first dimension of a is determined dynamically
複製程式碼
您可以使用tf.reshape
函式動態重塑給定的張量:
a = tf.reshape(a, [32, 128])
複製程式碼
一個在可用時返回靜態形狀,不可用時返回動態形狀的函式會很方便。一個比較好的程式實現如下:
def get_shape(tensor):
static_shape = tensor.shape.as_list()
dynamic_shape = tf.unstack(tf.shape(tensor))
dims = [s[1] if s[0] is None else s[0]
for s in zip(static_shape, dynamic_shape)]
return dims
複製程式碼
現在假設我們想通過將第二維和第三維摺疊成一個來將三維張量轉換為二維張量。我們可以使用get_shape()
函式來做到這一點:
b = tf.placeholder(tf.float32, [None, 10, 32])
shape = get_shape(b)
b = tf.reshape(b, [shape[0], shape[1] * shape[2]])
複製程式碼
值得注意的是,無論形狀是否靜態指定,這都有效。
事實上,我們可以編寫一個通用的reshape()
函式來摺疊任何維度列表:
import tensorflow as tf
import numpy as np
def reshape(tensor, dims_list):
shape = get_shape(tensor)
dims_prod = []
for dims in dims_list:
if isinstance(dims, int):
dims_prod.append(shape[dims])
elif all([isinstance(shape[d], int) for d in dims]):
dims_prod.append(np.prod([shape[d] for d in dims]))
else:
dims_prod.append(tf.prod([shape[d] for d in dims]))
tensor = tf.reshape(tensor, dims_prod)
return tensor
複製程式碼
然後摺疊第二個維度就會變得非常容易:
b = tf.placeholder(tf.float32, [None, 10, 32])
b = reshape(b, [0, [1, 2]])
複製程式碼