TF2.1進階-tf.booleanmask_where_scatter_nd_mashgrid
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
print(tf.__version__)
2.1.0
# 4 35 8 4個班級,每個班有35名學生,每名學生有8個科目要上
x = tf.random.uniform([4,35,8],maxval = 100,dtype=tf.int32)
x
<tf.Tensor: shape=(4, 35, 8), dtype=int32, numpy=
array([[[64, 17, 0, ..., 95, 78, 21],
[92, 60, 30, ..., 0, 86, 71],
[40, 23, 69, ..., 90, 74, 31],
...,
[15, 0, 52, ..., 68, 17, 6],
[48, 54, 28, ..., 99, 24, 2],
[ 2, 0, 17, ..., 18, 74, 0]],
[[42, 88, 70, ..., 83, 74, 64],
[79, 92, 32, ..., 93, 79, 28],
[20, 76, 26, ..., 54, 96, 80],
...,
[81, 66, 18, ..., 84, 88, 11],
[30, 72, 38, ..., 8, 94, 43],
[24, 81, 19, ..., 53, 55, 14]],
[[43, 44, 26, ..., 11, 10, 31],
[44, 98, 23, ..., 53, 88, 81],
[70, 29, 73, ..., 62, 77, 1],
...,
[73, 57, 26, ..., 7, 56, 74],
[44, 25, 3, ..., 73, 61, 63],
[55, 42, 50, ..., 35, 19, 34]],
[[58, 17, 39, ..., 6, 22, 27],
[86, 49, 51, ..., 50, 25, 19],
[75, 60, 4, ..., 88, 67, 49],
...,
[83, 47, 36, ..., 32, 14, 45],
[95, 3, 40, ..., 35, 14, 23],
[78, 77, 77, ..., 92, 8, 7]]])>
使用掩碼的方法採取資料
tf.boolean_mask()
掩碼的長度,需要和維度相等,四個維度上採取長度需要有四個布林值
# 掩碼的長度,需要和維度相等,四個維度上採取長度需要有四個布林值
mask = [True,False,False,True]
tf.boolean_mask(x,mask,axis=0)
<tf.Tensor: shape=(2, 35, 8), dtype=int32, numpy=
array([[[64, 17, 0, 20, 72, 95, 78, 21],
[92, 60, 30, 36, 33, 0, 86, 71],
[40, 23, 69, 50, 47, 90, 74, 31],
[ 8, 17, 44, 91, 27, 86, 26, 27],
[15, 16, 6, 29, 98, 5, 17, 65],
[ 4, 86, 32, 26, 42, 23, 76, 52],
[75, 82, 53, 98, 56, 32, 92, 34],
[81, 55, 17, 47, 80, 64, 40, 79],
[73, 40, 7, 85, 60, 67, 92, 81],
[32, 58, 3, 48, 49, 85, 4, 51],
[12, 55, 55, 12, 43, 30, 10, 11],
[25, 38, 19, 97, 99, 32, 28, 36],
[30, 71, 76, 75, 77, 66, 30, 70],
[31, 4, 1, 82, 13, 5, 16, 62],
[90, 43, 5, 36, 8, 54, 88, 45],
[96, 49, 80, 14, 23, 79, 55, 17],
[ 2, 96, 55, 8, 89, 87, 56, 37],
[98, 42, 50, 71, 55, 14, 52, 57],
[56, 18, 61, 52, 39, 69, 51, 28],
[37, 34, 38, 52, 28, 51, 74, 64],
[74, 56, 47, 84, 92, 34, 89, 65],
[73, 87, 2, 43, 7, 43, 63, 57],
[26, 83, 94, 42, 22, 95, 81, 57],
[89, 54, 16, 3, 99, 49, 40, 1],
[72, 76, 5, 99, 40, 20, 11, 91],
[62, 98, 57, 64, 42, 13, 25, 39],
[26, 53, 42, 12, 73, 60, 38, 35],
[44, 10, 52, 23, 32, 23, 7, 42],
[36, 10, 84, 37, 40, 53, 90, 61],
[85, 48, 76, 43, 96, 87, 4, 75],
[51, 54, 44, 2, 96, 69, 78, 27],
[99, 31, 97, 10, 6, 21, 41, 88],
[15, 0, 52, 36, 25, 68, 17, 6],
[48, 54, 28, 63, 70, 99, 24, 2],
[ 2, 0, 17, 51, 82, 18, 74, 0]],
[[58, 17, 39, 15, 24, 6, 22, 27],
[86, 49, 51, 12, 96, 50, 25, 19],
[75, 60, 4, 64, 46, 88, 67, 49],
[23, 87, 42, 85, 86, 31, 48, 76],
[18, 69, 24, 52, 19, 7, 86, 14],
[11, 74, 64, 57, 52, 29, 57, 35],
[25, 14, 1, 25, 3, 38, 38, 44],
[33, 16, 70, 53, 14, 1, 99, 63],
[70, 44, 36, 83, 69, 99, 90, 72],
[34, 90, 87, 95, 20, 38, 48, 55],
[51, 77, 32, 30, 35, 11, 96, 52],
[88, 28, 90, 96, 39, 44, 45, 59],
[17, 79, 34, 18, 10, 86, 31, 70],
[54, 48, 84, 53, 61, 81, 36, 22],
[85, 78, 72, 90, 18, 69, 22, 32],
[99, 82, 23, 32, 78, 32, 90, 9],
[53, 73, 14, 78, 21, 84, 85, 50],
[91, 22, 57, 73, 67, 30, 57, 64],
[88, 49, 34, 93, 50, 98, 40, 29],
[67, 96, 99, 74, 83, 14, 3, 23],
[71, 16, 83, 87, 36, 5, 10, 95],
[44, 15, 31, 77, 98, 31, 14, 52],
[33, 76, 13, 49, 85, 57, 49, 21],
[ 5, 14, 30, 79, 64, 83, 87, 68],
[51, 71, 77, 14, 37, 46, 72, 27],
[52, 66, 36, 41, 53, 39, 53, 70],
[12, 85, 56, 50, 21, 32, 60, 76],
[81, 35, 54, 54, 24, 75, 26, 74],
[10, 6, 72, 66, 2, 98, 57, 57],
[64, 77, 45, 39, 45, 36, 34, 67],
[13, 79, 96, 86, 3, 26, 13, 89],
[78, 42, 89, 68, 49, 90, 21, 13],
[83, 47, 36, 75, 16, 32, 14, 45],
[95, 3, 40, 60, 64, 35, 14, 23],
[78, 77, 77, 45, 1, 92, 8, 7]]])>
tf.where(cond,x,y)
cond是條件,如果條件為真,則取x的值,反之取y值
cond,x和y的型別都是相同的,均為張量
不指定x,y時,會直接返回真的值的索引
用於提取一個很大的矩陣中,大於零的索引,先和0進行比較然後用tf.where
a = tf.ones([2,2])
b = tf.zeros([2,2])
a,b
(<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 1.],
[1., 1.]], dtype=float32)>,
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 0.],
[0., 0.]], dtype=float32)>)
cond = tf.constant([
[True,False],
[False,True]
])
cond.shape
TensorShape([2, 2])
tf.where(cond,a,b)
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1., 0.],
[0., 1.]], dtype=float32)>
tf.where(cond)
<tf.Tensor: shape=(2, 2), dtype=int64, numpy=
array([[0, 0],
[1, 1]], dtype=int64)>
#### 用處提取一個很大的矩陣中,大於零的索引,先和0進行比較然後用tf.where
x = tf.random.normal([3,3])
x
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[-0.52138734, -1.0115561 , -0.3441449 ],
[ 1.1962733 , 0.27069718, 0.22945389],
[ 1.4716065 , 2.7272916 , 1.391508 ]], dtype=float32)>
mask = x > 0
mask
<tf.Tensor: shape=(3, 3), dtype=bool, numpy=
array([[False, False, False],
[ True, True, True],
[ True, True, True]])>
indices = tf.where(mask)
tf.gather_nd(x,indices)
<tf.Tensor: shape=(6,), dtype=float32, numpy=
array([1.1962733 , 0.27069718, 0.22945389, 1.4716065 , 2.7272916 ,
1.391508 ], dtype=float32)>
tf.scatter_nd(indices,updatas,shape)
可以高效的重新整理張量的資料
indices = tf.constant([ [4],[3],[1],[7] ])
updatas = tf.constant([ 1.,2.,3.,4. ])
tf.scatter_nd(indices,updatas,[8])
<tf.Tensor: shape=(8,), dtype=float32, numpy=array([0., 3., 0., 2., 1., 0., 0., 4.], dtype=float32)>
tf.meshgrid
用於生成二維網路取樣座標,方便三維視覺化
$ z = x^2 + y^2 $
for x in range(-8,8): #注意python自帶的range無法將間隔設定為浮點數,但是numpy的range可以
print(x)
-8
-7
-6
-5
-4
-3
-2
-1
0
1
2
3
4
5
6
7
points =[]
for x in np.arange(-8,8,0.1): #注意python自帶的range無法將間隔設定為浮點數,但是numpy的range可以
for y in np.arange(-8,8,0.1):
z = x**2 + y**2
points.append([x,y,z])
len(points)
25600
# 但是如果像上面這樣效率會很低,所以用meshgrid直接生成網格點
x = tf.linspace(-8.,8.,160)
y = tf.linspace(-8.,8.,160)
x,y = tf.meshgrid(x,y)
x.shape,y.shape
(TensorShape([160, 160]), TensorShape([160, 160]))
z = x**2 + y**2
z.shape
TensorShape([160, 160])
fig = plt.figure()
ax = Axes3D(fig)
ax.contour3D(x,y,z,500)
plt.show()
相關文章
- 高階前端進階(三)前端
- 高階前端進階(七)前端
- 高階前端進階(五)前端
- vue進階Vue
- SQL進階SQL
- protobuf進階
- gRPC進階RPC
- HBase進階
- 06進階
- Redux 進階Redux
- JavaScript進階JavaScript
- ElasticSearch 進階Elasticsearch
- Python進階Python
- Vuejs進階知識(十八)【component 進階知識】VueJS
- React 進階(三) 高階元件React元件
- Typescript 高階語法進階TypeScript
- 高階前端進階系列 - webview前端WebView
- Java小白進階筆記(5)-進階物件導向Java筆記物件
- sqlmap 進階 (一)SQL
- RocketMQ進階技巧MQ
- Python進階 — matplotlibPython
- Flutter redux 進階FlutterRedux
- React 進階一React
- 進階筆記筆記
- Redux進階(一)Redux
- CSS進階 --- BFCCSS
- vue進階二Vue
- vue進階一Vue
- Linux進階命令Linux
- RSA進階(一)
- SQLMAP進階使用SQL
- Python進階之道Python
- Sanic 路由進階路由
- DP進階合集
- Python進階 -- matplotlibPython
- printf 進階用法
- Spring Security進階Spring
- Arthas 進階教程