Tensorflow的資料輸入模組tf.data模組
通過列表建立dataset
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
for ele in dataset:
print(ele.numpy())
print(dataset)
通過字典建立dataset
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'c': [12, 23, 23, 34]})
for ele in dataset:
print(ele)
通過take方法取資料
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
print(dataset.take(1))
通過迭代取資料
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
print(next(iter(dataset.take(1))))
資料亂序
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
for ele in dataset:
print(ele.numpy())
dataset=dataset.shuffle(3)
for ele in dataset:
print(ele.numpy())
當資料集重複
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
dataset=dataset.repeat(count=3)
for ele in dataset:
print(ele.numpy())
取出資料
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
dataset=dataset.batch(3)
for ele in dataset:
print(ele.numpy())
對每一個元素進行操作
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 4, 5, 6, 7])
dataset=dataset.map(tf.square)
for ele in dataset:
print(ele.numpy())
例項
# -*-coding= utf-8 -*-
import tensorflow as tf
from fashionMnistUtil import get_data
(train_image, train_label), (test_image, test_label) = get_data()
train_image = train_image / 255
test_image = test_image / 255
ds_train_img = tf.data.Dataset.from_tensor_slices(train_image)
ds_test = tf.data.Dataset.from_tensor_slices((test_image, test_label))
ds_train_label = tf.data.Dataset.from_tensor_slices(train_label)
ds_train = tf.data.Dataset.zip((ds_train_img, ds_train_label))
ds_train = ds_train.shuffle(10000).repeat().batch(64)
ds_test = ds_test.batch(64)
model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(ds_train,
epochs=5,
steps_per_epoch=train_image.shape[0] // 64,
validation_data=ds_test,
validation_steps=10000 // 64)
相關文章
- TensorFlow Data模組
- 序列化模組,隨機數模組,os模組,sys模組,hashlib模組隨機
- appium 資料引數化 登入模組APP
- Python入門(二十六):檔案模組(os模組與shutil模組)Python
- 資料分析---matplotlib模組
- 資料分析---pandas模組
- swoole 模組的載入
- python 模組:itsdangerous 模組Python
- path模組 fs模組
- Python模組:time模組Python
- TypeScript入門-模組TypeScript
- swiper 模組載入
- day18:json模組&time模組&zipfile模組JSON
- SAP PM 入門系列14 – PM模組與其它模組的整合
- Python模組之urllib模組Python
- python模組之collections模組Python
- CommonJS模組 和 ECMAScript模組JS
- drozer模組的編寫及模組動態載入問題研究
- 序列化模組,subprocess模組,re模組,常用正則
- 模組匯入小結
- Python入門—time模組Python
- py模組匯入示例
- Python 模組匯入方式Python
- 聊天模組及分享模組分享
- [Python模組學習] glob模組Python
- 模組學習之hashlib模組
- 模組學習之logging模組
- 用 shelve 模組來存資料
- Angular入門到精通系列教程(11)- 模組(NgModule),延遲載入模組Angular
- Python常用模組(random隨機模組&json序列化模組)Pythonrandom隨機JSON
- 讀懂CommonJS的模組載入JS
- python之匯入模組的方法Python
- Python 模組的載入順序Python
- 模組
- 請問在Home或者Admin模組下如何進入Addons模組
- 第217篇:Cameralink轉光纖資料傳輸板- Base Camera link輸入轉光纖輸出模組
- 開源公開課丨 ChunJun 資料傳輸模組介紹
- ECDSA—模乘模組