Tensorflow的資料輸入模組tf.data模組
通過列表建立dataset
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
for ele in dataset:
print(ele.numpy())
print(dataset)
通過字典建立dataset
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'c': [12, 23, 23, 34]})
for ele in dataset:
print(ele)
通過take方法取資料
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
print(dataset.take(1))
通過迭代取資料
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
print(next(iter(dataset.take(1))))
資料亂序
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
for ele in dataset:
print(ele.numpy())
dataset=dataset.shuffle(3)
for ele in dataset:
print(ele.numpy())
當資料集重複
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
dataset=dataset.repeat(count=3)
for ele in dataset:
print(ele.numpy())
取出資料
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [4, 5], [6, 7]])
dataset=dataset.batch(3)
for ele in dataset:
print(ele.numpy())
對每一個元素進行操作
# -*-coding= utf-8 -*-
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 4, 5, 6, 7])
dataset=dataset.map(tf.square)
for ele in dataset:
print(ele.numpy())
例項
# -*-coding= utf-8 -*-
import tensorflow as tf
from fashionMnistUtil import get_data
(train_image, train_label), (test_image, test_label) = get_data()
train_image = train_image / 255
test_image = test_image / 255
ds_train_img = tf.data.Dataset.from_tensor_slices(train_image)
ds_test = tf.data.Dataset.from_tensor_slices((test_image, test_label))
ds_train_label = tf.data.Dataset.from_tensor_slices(train_label)
ds_train = tf.data.Dataset.zip((ds_train_img, ds_train_label))
ds_train = ds_train.shuffle(10000).repeat().batch(64)
ds_test = ds_test.batch(64)
model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(ds_train,
epochs=5,
steps_per_epoch=train_image.shape[0] // 64,
validation_data=ds_test,
validation_steps=10000 // 64)
相關文章
- TensorFlow Data模組
- appium 資料引數化 登入模組APP
- 資料分析---pandas模組
- swoole 模組的載入
- 以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018優化Go
- 開源公開課丨 ChunJun 資料傳輸模組介紹
- 資料的輸入輸出
- js模組化之自定義模組(頁面模組化載入)JS
- 資料分析---matplotlib模組
- 第217篇:Cameralink轉光纖資料傳輸板- Base Camera link輸入轉光纖輸出模組
- TypeScript入門-模組TypeScript
- 模組載入器
- swiper 模組載入
- Python入門(二十六):檔案模組(os模組與shutil模組)Python
- SAP PM 入門系列14 – PM模組與其它模組的整合
- JavaScript 模組的迴圈載入JavaScript
- drozer模組的編寫及模組動態載入問題研究
- py模組匯入示例
- 模組匯入小結
- JavaScript 模組載入特性JavaScript
- Webpack模組載入器Web
- php載入memcache模組PHP
- CO模組成本物件控制的主資料物件
- Linux核心模組程式設計-將/proc作為輸入(轉)Linux程式設計
- 03 資料輸入-輸出
- Angular入門到精通系列教程(11)- 模組(NgModule),延遲載入模組Angular
- 用 shelve 模組來存資料
- django哪個模組配置資料庫Django資料庫
- iOS資料上報模組封裝方案iOS封裝
- SOFARegistry 原始碼|資料同步模組解析原始碼
- 使用pprint模組格式化資料
- tensorflow2.0 自定義類模組列印問題
- 開源交流丨批流一體資料整合框架ChunJun資料傳輸模組詳解分享框架
- tensorflow載入資料的三種方式
- DBI 資料庫模組剖析:Perl DBI 資料庫通訊模組規範,工作原理和例項資料庫
- 簡單的資料輸入
- Python資料的輸入與輸出Python
- TensorFlow 入門(MNIST資料集)