tf.shape()和getshape的區別

獨鹿發表於2018-05-08

參考資料:https://blog.csdn.net/chenxieyy/article/details/53020760

1,tf.shape(a)和a.get_shape()比較

   相同點:都可以得到tensor a的尺寸

   不同點:tf.shape()中a 資料的型別可以是tensor, list, array

                 a.get_shape()中a的資料型別只能是tensor,且返回的是一個元組(tuple)

2.獲取tensor的維度x.get_shape().with_rank(3),3為tensor維度,返回的還是一個元組和不加的時候一樣,如果加了with_rank但維數不對會報錯。如果維度錯誤也會報錯

3.tf.stack()這是一個矩陣拼接的函式(裡面傳兩個值,第一個是要拼接的列表,第二個是從哪個方向拼接,0表示豎直方向,把列表中的物件逐個新增到一個列表中,1表示橫著拼接,把元素中的對應值拼接完加到列表中,之多元素對應值都拼接完。),

list_ = [[1, 2, 3], [4, 5, 6]]

print(type(list_))
sess=tf.Session()
a = tf.stack(list_, 0)
print sess.run(a)
a = tf.stack(list_, 1)

print sess.run(a)

[[1 2 3]
 [4 5 6]]
[[1 4]
 [2 5]

 [3 6]]

4.tf.unstack()則是一個矩陣分解的函式(第一個值傳矩陣,第二個值傳分解的方向,如果為0將矩陣以豎直方向存放,如果為1將矩陣個列拼接為1個元素,然後存成列表)。

[array([1, 2, 3], dtype=int32), array([4, 5, 6], dtype=int32)]
[array([1, 4], dtype=int32), array([2, 5], dtype=int32), array([3, 6], dtype=int32)]

2,例子:

import tensorflow as tf
import numpy as np

x=tf.constant([[1,2,3],[4,5,6]])
y=[[1,2,3],[4,5,6]]
z=np.arange(24).reshape([2,3,4])

sess=tf.Session()
# tf.shape()
x_shape=tf.shape(x)
y_shape=tf.shape(y)
z_shape=tf.shape(z)
print sess.run(x_shape)
print sess.run(y_shape)
print sess.run(z_shape)

#a.get_shape()

x_shape=x.get_shape()
print x_shape
x_shape=x.get_shape().as_list()
print x_shape
z_shape =z.get_shape()
print z_shape
# y_shape=y.get_shape()  # AttributeError: 'list' object has no attribute 'get_shape'
# z_shape=z.get_shape()  # AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'

而list,numpy的物件是沒有getshape的屬性的,只有tensor有getshape.

getshape()返回一個元組,通過aslist()方法可以把一個元組轉為list

x_shape=x.get_shape()
print x_shape

x_shape=x.get_shape().as_list()

(2, 3)
[2, 3]


相關文章