Tensorflow教程(2)Tensorflow的常用函式介紹

盛世芳華發表於2019-07-14

1、tf.constant

tf.constant用來定義一個常量,所謂常量,廣義上講就是“不變化的量”。我們先看下官方api是如何對constant函式來定義的:

tf.constant(
    value,
    dtype=None,
    shape=None,
    name='Const',
    verify_shape=False
)

其中包括5個輸入值:

value(必填):常量值,可以是一個數,也可以是一個向量或矩陣。

dtype(非必填):用來指定資料型別,例如tf.float32型別或tf.float64。

shape(非必填):用來指定資料的維度。

name(非必填):為常量定義名稱。

verify_shape(非必填):預設值為False,如果值為True時,在定義常量時會自動檢測value和shape維度是否相同,不同則報錯,例如value定義為1,而shape定義為一行兩列的矩陣(1,2),那麼肯定會報錯。

為了瞭解引數的具體含義,我們用程式碼來驗證一下吧!

#定義一個整數
a = tf.constant(1)
#定義一個向量
b = tf.constant([1,2])
#定義一個2行3列的矩陣
c = tf.constant([[1,2,3],[4,5,6]])
print(a)
print(b)
print(c)

輸出結果:

Tensor("Const:0", shape=(), dtype=int32)
Tensor("Const_1:0", shape=(2,), dtype=int32)
Tensor("Const_2:0", shape=(2, 3), dtype=int32)

變數a的shape為空,0個緯度,也就是一個數值;

變數b的shape是(2,),只有一個維度,是一個向量,長度為2;

變數c的shape是(2,3),有兩個維度,也就是一個2X3的矩陣。

當指定dtype引數時:

#定義一個整數
a = tf.constant(1,dtype=tf.float32)
#定義一個向量
b = tf.constant([1,2],dtype=tf.float32)
#定義一個2行3列的矩陣
c = tf.constant([[1,2,3],[4,5,6]],dtype=tf.float32)
print(a)
print(b)
print(c)

輸出結果:

Tensor("Const:0", shape=(), dtype=float32)
Tensor("Const_1:0", shape=(2,), dtype=float32)
Tensor("Const_2:0", shape=(2, 3), dtype=float32)

可見數值的型別都變為float32型別。

當指定shape引數時:

#定義一個整數
a = tf.constant(2.,shape=())
b = tf.constant(2.,shape=(3,))
c = tf.constant(2.,shape=(3,4))
with tf.Session() as sess:
    print(a.eval())
    print(b.eval())
    print(c.eval())

輸出結果:

2.0
[2. 2. 2.]
[[2. 2. 2. 2.]
 [2. 2. 2. 2.]
 [2. 2. 2. 2.]]

此時constant會根據shape指定的維度使用value值來進行填充,例如引數a指定維度為0,也就是一個整數;引數b指定維度為1長度為3,也就是一個向量;引數b指定維度為2長度為3X4,也就是定義一個3X4的矩陣,全部都使用value值2.0來進行填充。

當指定name引數時:

#不指定name
a = tf.constant(2.)
#指定name
b = tf.constant(2.,name="b")
print(a)
print(b)

輸出結果:

Tensor("Const:0", shape=(), dtype=float32)
Tensor("b:0", shape=(), dtype=float32)

建議大家建立常量時最好定義一下name,只要是字串就沒有問題。

當指定verify_shape=True時:

a = tf.constant(2.,shape=(2,3),verify_shape=True)

輸出結果報錯:

TypeError: Expected Tensor's shape: (2,3), got ().

錯誤原因是value的值和指定的shape維度不同,value是一個數值,而我們指定的shape為2X3的矩陣,所以報錯!當我們去掉verify_shape引數時錯誤即消失。那麼問題來了,此時這個常量到底是整數還是一個矩陣呢?當然是矩陣啦(一個被value值填充的2X3矩陣)!

相關文章