tf.tile()進行張量擴充套件

Candy_GL發表於2018-07-23

tf.tile()進行張量擴充套件

tf.tile()應用於需要張量擴充套件的場景,具體說來就是: 
如果現有一個形狀如[widthheight]的張量,需要得到一個基於原張量的,形狀如[batch_size,width,height]的張量,其中每一個batch的內容都和原張量一模一樣。tf.tile使用方法如:

tile(
    input,
    multiples,
    name=None
)
  • 1
  • 2
  • 3
  • 4
  • 5

其中輸出將會重複input輸入multiples次。例子如:

import tensorflow as tf

raw = tf.Variable(tf.random_normal(shape=(1, 3, 2)))
multi = tf.tile(raw, multiples=[2, 1, 1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(raw.eval())
    print('-----------------------------')
    print(sess.run(multi))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

輸出如:

[[[-0.50027871 -0.48475555]
  [-0.52617502 -0.2396145 ]
  [ 1.74173343 -0.20627949]]]
-----------------------------
[[[-0.50027871 -0.48475555]
  [-0.52617502 -0.2396145 ]
  [ 1.74173343 -0.20627949]]

 [[-0.50027871 -0.48475555]
  [-0.52617502 -0.2396145 ]
  [ 1.74173343 -0.20627949]]]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

可見,multi重複了raw的0 axes兩次,1和2 axes不變。

相關文章