TensorFlow Dataset API (tf.data)

zhangztSky發表於2020-12-18
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras
#
# (x,y), (x_test, y_test) = keras.datasets.cifar100.load_data()
# y = tf.squeeze(y, axis=1)
# y_test = tf.squeeze(y_test, axis=1)
# print(x.shape, y.shape, x_test.shape, y_test.shape)
# train_db = tf.data.Dataset.from_tensor_slices((x,y))

# Dataset.repeat(count) replays the whole input sequence `count` times, in order.
# The dtype is pinned to int32 so the output recorded below (and in the
# shuffle/batch demo that follows) is reproducible everywhere: np.arange's
# default integer dtype is platform-dependent (C long before NumPy 2.0, i.e.
# 32-bit on Windows but 64-bit on Linux/macOS), which would otherwise print
# dtype=int64 on many machines.
dataset = tf.data.Dataset.from_tensor_slices(np.arange(5, dtype=np.int32))
dataset = dataset.repeat(3)
for i in dataset:
    print(i)
# Output: the values 0..4 three times over, each as a scalar tensor:
# tf.Tensor(0, shape=(), dtype=int32)
# tf.Tensor(1, shape=(), dtype=int32)
# tf.Tensor(2, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)
# tf.Tensor(0, shape=(), dtype=int32)
# tf.Tensor(1, shape=(), dtype=int32)
# tf.Tensor(2, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)
# tf.Tensor(0, shape=(), dtype=int32)
# tf.Tensor(1, shape=(), dtype=int32)
# tf.Tensor(2, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)
print("*" * 30)
# Shuffle first, then batch. The shuffle buffer (100) is larger than the
# 15-element repeated dataset, so all elements are mixed before grouping.
dataset = dataset.shuffle(100)
dataset = dataset.batch(4)
print(type(dataset))
for batch in dataset:
    print(batch)
# Recorded output for dataset.shuffle(100).batch(4): values from different
# repetitions end up in the same batch, and only the final batch is short.
#   tf.Tensor([3 0 3 1], shape=(4,), dtype=int32)
#   tf.Tensor([3 1 2 4], shape=(4,), dtype=int32)
#   tf.Tensor([0 4 4 2], shape=(4,), dtype=int32)
#   tf.Tensor([0 1 2], shape=(3,), dtype=int32)
#
# For comparison (not executed here), dataset.batch(4).shuffle(100): each
# batch keeps consecutive values of the repeated sequence and only the order
# of whole batches is shuffled, so the short batch can show up anywhere.
#   tf.Tensor([3 4 0 1], shape=(4,), dtype=int32)
#   tf.Tensor([2 3 4], shape=(3,), dtype=int32)
#   tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
#   tf.Tensor([4 0 1 2], shape=(4,), dtype=int32)


# A tuple of arrays is sliced along the first axis in lockstep: element k of
# the resulting dataset is the pair (x[k], y[k]).
x = np.array([
    [1, 2],
    [3, 4],
    [5, 6],
])
y = np.array(['cat', 'dog', 'fox'])

dataset3 = tf.data.Dataset.from_tensor_slices((x, y))
# Printing the Dataset object describes the dataset itself (presumably its
# element spec); it does not iterate over the elements.
print(dataset3)




# A dict of arrays is sliced the same way: every element is a dict with the
# same keys, holding the k-th row of each array.
dataset4 = tf.data.Dataset.from_tensor_slices(
    {
        "feature": x,
        "label": y,
    }
)
for record in dataset4:
    feature, label = record["feature"], record["label"]
    print(feature.numpy(), label.numpy())

# Output (string tensors come back from .numpy() as bytes):
# [1 2] b'cat'
# [3 4] b'dog'
# [5 6] b'fox'



# Dataset.interleave maps every element of `a` to its own sub-dataset and
# draws from the open sub-datasets in round-robin order:
#   cycle_length=2 -> at most two sub-datasets are open at the same time;
#   block_length=4 -> take up to 4 consecutive items from one sub-dataset
#                     before moving on to the next open one.
a = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# Each value becomes a sub-dataset that yields that value six times. The
# lambda parameter is named `value` so it is not confused with the
# module-level array `x`.
b = a.interleave(lambda value: tf.data.Dataset.from_tensors(value).repeat(6),
                 cycle_length=2, block_length=4)
for item in b:
    print(item.numpy(), end=', ')
# Every item is printed on the same line, so block boundaries are not visible
# in the output; terminate the line so later output does not run onto it.
print()

# Output:
# 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5,
# The same sequence split into blocks (the "|" separators are not printed):
# 1, 1, 1, 1 | 2, 2, 2, 2 | 1, 1 | 2, 2 | 3, 3, 3, 3 | 4, 4, 4, 4 | 3, 3 | 4, 4 | 5, 5, 5, 5 | 5, 5
# After a full block of 4, each six-item sub-dataset has only 2 items left,
# hence the short blocks. Value 5 has no partner in its cycle, so its two
# blocks come out back to back.

參考: tf.data.Dataset.interleave

相關文章