深度學習系列教程(六)tf.data API 使用方法介紹

五嶽凌峰發表於2018-12-05

轉載自https://zhuanlan.zhihu.com/p/32649553。謝謝作者辛苦整理。若侵權,告知即刪。

傾心之作!天學網AI學院名師團隊“玩轉TensorFlow與深度學習模型”系列文字教程,本週帶來tf.data API 使用方法介紹!

該教程通過知識點講解+答疑指導相結合的方式,讓大家循序漸進的瞭解深度學習模型並通過實操演示掌握相關框架及TensorFlow工具使用。

大家在學習和實操過程中,有任何疑問都可以通過學院微信交流群進行提問,有導師和助教、大牛等為您解答哦。(入群方式在文末

第六篇的教程主要內容:TensorFlow 資料匯入 (tf.data API 使用介紹)。

tf.data 簡介

以往的TensorFLow模型資料的匯入方法可以分為兩個主要方法,一種是使用feed_dict另外一種是使用TensorFlow中的Queues。前者使用起來比較靈活,可以利用Python處理各種輸入資料,劣勢也比較明顯,就是程式執行效率較低;後面一種方法的效率較高,但是使用起來較為複雜,靈活性較差。

Dataset作為新的API,比以上兩種方法的速度都快,並且使用難度要遠遠低於使用Queuestf.data中包含了兩個用於TensorFLow程式的介面:DatasetIterator

Dataset(資料集) API 在 TensorFlow 1.4版本中已經從tf.contrib.data遷移到了tf.data之中,增加了對於Python的生成器的支援,官方強烈建議使用Dataset API 為 TensorFlow模型建立輸入管道,原因如下:

  • 與舊 API(feed_dict 或佇列式管道)相比,Dataset API 可以提供更多功能。
  • Dataset API 的效能更高。
  • Dataset API 更簡潔,更易於使用。

 

將來 TensorFlow 團隊將會將開發中心放在Dataset API而不是舊的API上。

Dataset

Dataset表示一個元素的集合,可以看作函數語言程式設計中的 lazy list, 元素是tensor tuple。建立Dataset的方式可以分為兩種,分別是:

  • Source
  • Apply transformation

Source

這裡 source 指的是從tf.Tensor物件建立Dataset,常見的方法又如下幾種:

tf.data.Dataset.from_tensors((features, labels))
tf.data.Dataset.from_tensor_slices((features, labels))
tf.data.TextLineDataset(filenames)
tf.data.TFRecordDataset(filenames)

作用分別為:從一個tensor tuple建立一個單元素的dataset;從一個tensor tuple建立一個包含多個元素的dataset;讀取一個檔名列表,將每個檔案中的每一行作為一個元素,構成一個dataset;讀取硬碟中的TFRecord格式檔案,構造dataset。

Apply transformation

第二種方法就是通過轉化已有的dataset來得到新的dataset,TensorFLow tf.data.Dataset支援很多中變換,在這裡介紹常見的幾種:

dataset.map(lambda x: tf.decode_jpeg(x))
dataset.repeat(NUM_EPOCHS)
dataset.batch(BATCH_SIZE)

以上三種方式分別表示了:使用map對dataset中的每個元素進行處理,這裡的例子是對圖片資料進行解碼;將dataset重複一定數目的次數用於多個epoch的訓練;將原來的dataset中的元素按照某個數量疊在一起,生成mini batch。

TensorFlow 1.4 版本中還允許使用者通過Python的生成器構造dataset,如:

def generator():
  while True:
    yield ...

dataset = tf.data.Dataset.from_generator(generator, tf.int32)

將以上程式碼組合起來,我們可以得到一個常用的程式碼片段:

# 從一個檔名列表讀取 TFRecord 構成 dataset
dataset = TFRecordDataset(["file1.tfrecord", "file2.tfrecord"])
# 處理 string,將 string 轉化為 tf.Tensor 物件
dataset = dataset.map(lambda record: tf.parse_single_example(record))
# buffer 大小設定為 10000,打亂 dataset
dataset = dataset.shuffle(10000)
# dataset 將被用來訓練 100 個 epoch
dataset = dataset.repeat(100)
# 設定 batch size 為 128
dataset = dataset.batch(128)

Iterator

定義好了資料集以後可以通過Iterator介面來訪問資料集中的tensor tuple,iterator保持了資料在資料集中的位置,提供了訪問資料集中資料的方法。

可以通過呼叫 dataset 的 make iterator 方法來構建 iterator。

API 支援以下四種 iterator,複雜程度遞增:

  • one-shot
  • initializable
  • reinitializable
  • feedable

one-shot

one-shot iterator 誰最簡單的一種 iterator,僅支援對整個資料集訪問一遍,不需要顯式的初始化。one-shot iterator 不支引數化。以下程式碼使用tf.data.Dataset.range生成資料集,作用與 python 中的 range 類似。

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)
  assert i == value

initializable

Initializable iterator 要求在使用之前顯式的通過呼叫iterator.initializer操作初始化,這使得在定義資料集時可以結合tf.placeholder傳入引數,如:

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value

reinitializable

reinitializable iterator 可以被不同的 dataset 物件初始化,比如對於訓練集進行了shuffle的操作,對於驗證集則沒有處理,通常這種情況會使用兩個具有相同結構的dataset物件,如:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.Iterator.from_structure(training_dataset.output_types,
                                   training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)

feedable

feedable iterator 可以通過和tf.placeholder結合在一起,同通過feed_dict機制來選擇在每次呼叫tf.Session.run的時候選擇哪種Iterator。它提供了與 reinitilizable iterator 類似的功能,並且在切換資料集的時候不需要在開始的時候初始化iterator,還是上面的例子,通過tf.data.Iterator.from_string_handle來定義一個 feedable iterator,達到切換資料集的目的:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})
程式碼示例

這裡舉一個讀取、解碼圖片,並且將圖片的大小進行調整的例子:

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

相關文章