TF2.1進階-tf.booleanmask_where_scatter_nd_mashgrid

加油噹噹發表於2020-12-20
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
print(tf.__version__)
2.1.0
# 4 35 8  4個班級,每個班有35名學生,每名學生有8個科目要上
x = tf.random.uniform([4,35,8],maxval = 100,dtype=tf.int32)
x
<tf.Tensor: shape=(4, 35, 8), dtype=int32, numpy=
array([[[64, 17,  0, ..., 95, 78, 21],
        [92, 60, 30, ...,  0, 86, 71],
        [40, 23, 69, ..., 90, 74, 31],
        ...,
        [15,  0, 52, ..., 68, 17,  6],
        [48, 54, 28, ..., 99, 24,  2],
        [ 2,  0, 17, ..., 18, 74,  0]],

       [[42, 88, 70, ..., 83, 74, 64],
        [79, 92, 32, ..., 93, 79, 28],
        [20, 76, 26, ..., 54, 96, 80],
        ...,
        [81, 66, 18, ..., 84, 88, 11],
        [30, 72, 38, ...,  8, 94, 43],
        [24, 81, 19, ..., 53, 55, 14]],

       [[43, 44, 26, ..., 11, 10, 31],
        [44, 98, 23, ..., 53, 88, 81],
        [70, 29, 73, ..., 62, 77,  1],
        ...,
        [73, 57, 26, ...,  7, 56, 74],
        [44, 25,  3, ..., 73, 61, 63],
        [55, 42, 50, ..., 35, 19, 34]],

       [[58, 17, 39, ...,  6, 22, 27],
        [86, 49, 51, ..., 50, 25, 19],
        [75, 60,  4, ..., 88, 67, 49],
        ...,
        [83, 47, 36, ..., 32, 14, 45],
        [95,  3, 40, ..., 35, 14, 23],
        [78, 77, 77, ..., 92,  8,  7]]])>

使用掩碼的方法採取資料

tf.boolean_mask()

掩碼的長度,需要和維度相等,四個維度上採取長度需要有四個布林值

# 掩碼的長度,需要和維度相等,四個維度上採取長度需要有四個布林值
mask = [True,False,False,True]
tf.boolean_mask(x,mask,axis=0)
<tf.Tensor: shape=(2, 35, 8), dtype=int32, numpy=
array([[[64, 17,  0, 20, 72, 95, 78, 21],
        [92, 60, 30, 36, 33,  0, 86, 71],
        [40, 23, 69, 50, 47, 90, 74, 31],
        [ 8, 17, 44, 91, 27, 86, 26, 27],
        [15, 16,  6, 29, 98,  5, 17, 65],
        [ 4, 86, 32, 26, 42, 23, 76, 52],
        [75, 82, 53, 98, 56, 32, 92, 34],
        [81, 55, 17, 47, 80, 64, 40, 79],
        [73, 40,  7, 85, 60, 67, 92, 81],
        [32, 58,  3, 48, 49, 85,  4, 51],
        [12, 55, 55, 12, 43, 30, 10, 11],
        [25, 38, 19, 97, 99, 32, 28, 36],
        [30, 71, 76, 75, 77, 66, 30, 70],
        [31,  4,  1, 82, 13,  5, 16, 62],
        [90, 43,  5, 36,  8, 54, 88, 45],
        [96, 49, 80, 14, 23, 79, 55, 17],
        [ 2, 96, 55,  8, 89, 87, 56, 37],
        [98, 42, 50, 71, 55, 14, 52, 57],
        [56, 18, 61, 52, 39, 69, 51, 28],
        [37, 34, 38, 52, 28, 51, 74, 64],
        [74, 56, 47, 84, 92, 34, 89, 65],
        [73, 87,  2, 43,  7, 43, 63, 57],
        [26, 83, 94, 42, 22, 95, 81, 57],
        [89, 54, 16,  3, 99, 49, 40,  1],
        [72, 76,  5, 99, 40, 20, 11, 91],
        [62, 98, 57, 64, 42, 13, 25, 39],
        [26, 53, 42, 12, 73, 60, 38, 35],
        [44, 10, 52, 23, 32, 23,  7, 42],
        [36, 10, 84, 37, 40, 53, 90, 61],
        [85, 48, 76, 43, 96, 87,  4, 75],
        [51, 54, 44,  2, 96, 69, 78, 27],
        [99, 31, 97, 10,  6, 21, 41, 88],
        [15,  0, 52, 36, 25, 68, 17,  6],
        [48, 54, 28, 63, 70, 99, 24,  2],
        [ 2,  0, 17, 51, 82, 18, 74,  0]],

       [[58, 17, 39, 15, 24,  6, 22, 27],
        [86, 49, 51, 12, 96, 50, 25, 19],
        [75, 60,  4, 64, 46, 88, 67, 49],
        [23, 87, 42, 85, 86, 31, 48, 76],
        [18, 69, 24, 52, 19,  7, 86, 14],
        [11, 74, 64, 57, 52, 29, 57, 35],
        [25, 14,  1, 25,  3, 38, 38, 44],
        [33, 16, 70, 53, 14,  1, 99, 63],
        [70, 44, 36, 83, 69, 99, 90, 72],
        [34, 90, 87, 95, 20, 38, 48, 55],
        [51, 77, 32, 30, 35, 11, 96, 52],
        [88, 28, 90, 96, 39, 44, 45, 59],
        [17, 79, 34, 18, 10, 86, 31, 70],
        [54, 48, 84, 53, 61, 81, 36, 22],
        [85, 78, 72, 90, 18, 69, 22, 32],
        [99, 82, 23, 32, 78, 32, 90,  9],
        [53, 73, 14, 78, 21, 84, 85, 50],
        [91, 22, 57, 73, 67, 30, 57, 64],
        [88, 49, 34, 93, 50, 98, 40, 29],
        [67, 96, 99, 74, 83, 14,  3, 23],
        [71, 16, 83, 87, 36,  5, 10, 95],
        [44, 15, 31, 77, 98, 31, 14, 52],
        [33, 76, 13, 49, 85, 57, 49, 21],
        [ 5, 14, 30, 79, 64, 83, 87, 68],
        [51, 71, 77, 14, 37, 46, 72, 27],
        [52, 66, 36, 41, 53, 39, 53, 70],
        [12, 85, 56, 50, 21, 32, 60, 76],
        [81, 35, 54, 54, 24, 75, 26, 74],
        [10,  6, 72, 66,  2, 98, 57, 57],
        [64, 77, 45, 39, 45, 36, 34, 67],
        [13, 79, 96, 86,  3, 26, 13, 89],
        [78, 42, 89, 68, 49, 90, 21, 13],
        [83, 47, 36, 75, 16, 32, 14, 45],
        [95,  3, 40, 60, 64, 35, 14, 23],
        [78, 77, 77, 45,  1, 92,  8,  7]]])>

tf.where(cond,x,y)

cond是條件,如果條件為真,則取x的值,反之取y值

cond,x和y的型別都是相同的,均為張量

不指定x,y時,會直接返回真的值的索引

用於提取一個很大的矩陣中,大於零的索引,先和0進行比較然後用tf.where

a = tf.ones([2,2])
b = tf.zeros([2,2])
a,b
(<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
 array([[1., 1.],
        [1., 1.]], dtype=float32)>,
 <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
 array([[0., 0.],
        [0., 0.]], dtype=float32)>)
cond = tf.constant([
                    [True,False],
                    [False,True]
                    ])
cond.shape
TensorShape([2, 2])
tf.where(cond,a,b)
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 0.],
       [0., 1.]], dtype=float32)>
tf.where(cond)
<tf.Tensor: shape=(2, 2), dtype=int64, numpy=
array([[0, 0],
       [1, 1]], dtype=int64)>
#### 用處提取一個很大的矩陣中,大於零的索引,先和0進行比較然後用tf.where
x = tf.random.normal([3,3])
x
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[-0.52138734, -1.0115561 , -0.3441449 ],
       [ 1.1962733 ,  0.27069718,  0.22945389],
       [ 1.4716065 ,  2.7272916 ,  1.391508  ]], dtype=float32)>
mask = x > 0
mask
<tf.Tensor: shape=(3, 3), dtype=bool, numpy=
array([[False, False, False],
       [ True,  True,  True],
       [ True,  True,  True]])>
indices = tf.where(mask)
tf.gather_nd(x,indices)
<tf.Tensor: shape=(6,), dtype=float32, numpy=
array([1.1962733 , 0.27069718, 0.22945389, 1.4716065 , 2.7272916 ,
       1.391508  ], dtype=float32)>

tf.scatter_nd(indices,updatas,shape)

可以高效的重新整理張量的資料

indices = tf.constant([ [4],[3],[1],[7] ])
updatas = tf.constant([ 1.,2.,3.,4. ])
tf.scatter_nd(indices,updatas,[8])
<tf.Tensor: shape=(8,), dtype=float32, numpy=array([0., 3., 0., 2., 1., 0., 0., 4.], dtype=float32)>

tf.meshgrid

用於生成二維網路取樣座標,方便三維視覺化

$ z = x^2 + y^2 $

for x in range(-8,8):  #注意python自帶的range無法將間隔設定為浮點數,但是numpy的range可以
    print(x)
-8
-7
-6
-5
-4
-3
-2
-1
0
1
2
3
4
5
6
7
points =[]
for x in np.arange(-8,8,0.1):  #注意python自帶的range無法將間隔設定為浮點數,但是numpy的range可以
    for y in np.arange(-8,8,0.1):
        z = x**2 + y**2
        points.append([x,y,z])
len(points)
25600
# 但是如果像上面這樣效率會很低,所以用meshgrid直接生成網格點
x = tf.linspace(-8.,8.,160)
y = tf.linspace(-8.,8.,160)
x,y = tf.meshgrid(x,y)
x.shape,y.shape
(TensorShape([160, 160]), TensorShape([160, 160]))
z = x**2 + y**2
z.shape
TensorShape([160, 160])
fig = plt.figure()
ax = Axes3D(fig)
ax.contour3D(x,y,z,500)
plt.show()

在這裡插入圖片描述