TensorFlow 的資料輸入模組：tf.data

作者：希望以後住大房子娶漂亮老婆　發表於 2020-11-21

通過列表建立dataset

# -*-coding= utf-8 -*-
import tensorflow as tf

# Each inner list becomes one element of the dataset.
rows = [[1, 2], [4, 5], [6, 7]]
dataset = tf.data.Dataset.from_tensor_slices(rows)

# Iterating eagerly yields tensors; .numpy() converts each one for printing.
for element in dataset:
    print(element.numpy())

# Printing the dataset object itself (not its elements).
print(dataset)

通過字典建立dataset

# -*-coding= utf-8 -*-
import tensorflow as tf

# A dict of equal-length columns; the dataset is sliced along the first axis,
# so every element is a dict holding one value per key.
columns = {
    'a': [1, 2, 3, 4],
    'b': [5, 6, 7, 8],
    'c': [12, 23, 23, 34],
}
dataset = tf.data.Dataset.from_tensor_slices(columns)

for record in dataset:
    print(record)

通過take方法取資料

# -*-coding= utf-8 -*-
import tensorflow as tf

rows = [[1, 2], [4, 5], [6, 7]]
dataset = tf.data.Dataset.from_tensor_slices(rows)

# take(1) returns another dataset limited to at most one element, so this
# prints the dataset object rather than the element's values.
first_only = dataset.take(1)
print(first_only)

通過迭代取資料

# -*-coding= utf-8 -*-
import tensorflow as tf

rows = [[1, 2], [4, 5], [6, 7]]
dataset = tf.data.Dataset.from_tensor_slices(rows)

# Build an explicit iterator over the one-element dataset and pull the
# single tensor out of it.
first_only = dataset.take(1)
iterator = iter(first_only)
print(next(iterator))

資料亂序

# -*-coding= utf-8 -*-
import tensorflow as tf

rows = [[1, 2], [4, 5], [6, 7]]
dataset = tf.data.Dataset.from_tensor_slices(rows)

# First pass: elements in their original order.
for element in dataset:
    print(element.numpy())

# Second pass: shuffled with a buffer of 3, which matches the number of
# elements in this dataset.
dataset = dataset.shuffle(buffer_size=3)
for element in dataset:
    print(element.numpy())

資料集重複

# -*-coding= utf-8 -*-
import tensorflow as tf

rows = [[1, 2], [4, 5], [6, 7]]
dataset = tf.data.Dataset.from_tensor_slices(rows)

# Repeat the whole dataset three times back to back.
dataset = dataset.repeat(3)

for element in dataset:
    print(element.numpy())

按批次取出資料

# -*-coding= utf-8 -*-
import tensorflow as tf

rows = [[1, 2], [4, 5], [6, 7]]
dataset = tf.data.Dataset.from_tensor_slices(rows)

# Group consecutive elements into batches of three; with three rows this
# produces a single batch.
dataset = dataset.batch(batch_size=3)

for batch in dataset:
    print(batch.numpy())

對每一個元素進行操作

# -*-coding= utf-8 -*-
import tensorflow as tf

values = [1, 2, 4, 5, 6, 7]
dataset = tf.data.Dataset.from_tensor_slices(values)

# Apply tf.square to every element of the dataset.
squared = dataset.map(tf.square)
dataset = squared

for element in dataset:
    print(element.numpy())

例項

# -*-coding= utf-8 -*-
import tensorflow as tf

from fashionMnistUtil import get_data

# Mini-batch size shared by the training and evaluation pipelines, so the
# step counts passed to fit() always agree with the batching used below.
BATCH_SIZE = 64

(train_image, train_label), (test_image, test_label) = get_data()

# Scale pixel values by 1/255.
# NOTE(review): assumes get_data() returns image arrays in the 0-255 range
# with shape (N, 28, 28) - confirm against fashionMnistUtil.
train_image = train_image / 255
test_image = test_image / 255

# Training pipeline: images and labels are sliced separately and paired
# with zip, then shuffled, repeated indefinitely, and batched.
ds_train_img = tf.data.Dataset.from_tensor_slices(train_image)
ds_train_label = tf.data.Dataset.from_tensor_slices(train_label)
ds_train = tf.data.Dataset.zip((ds_train_img, ds_train_label))
# repeat() with no count makes the dataset endless, which is why fit()
# needs an explicit steps_per_epoch.
ds_train = ds_train.shuffle(10000).repeat().batch(BATCH_SIZE)

# Evaluation pipeline: (image, label) pairs sliced together and batched.
ds_test = tf.data.Dataset.from_tensor_slices((test_image, test_label))
ds_test = ds_test.batch(BATCH_SIZE)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Derive both step counts from the actual array sizes instead of hard-coding
# the test-set size; floor division only counts full batches.
model.fit(
    ds_train,
    epochs=5,
    steps_per_epoch=train_image.shape[0] // BATCH_SIZE,
    validation_data=ds_test,
    validation_steps=test_image.shape[0] // BATCH_SIZE,
)

相關文章