TensorFlow Data模組

Alexanderhaha發表於2019-08-13

模組作用

tf.data api用於建立訓練前匯入資料和資料處理的pipeline,使得處理大規模資料,不同資料格式和複雜資料處理變的容易。

基本抽象

提供了兩種基本抽象:DatasetIterator

Dataset

表示元素序列集合,每個元素包含一個或者多個Tensor物件,每個元素是一個樣本。有兩種方式可以建立Dataset。

  1. 從源資料建立,比如:Dataset.from_tensor_slices()
  2. 通過資料處理轉換建立,比如 Dataset.map()/batch()

Iterator

用於從Dataset獲取資料給訓練作為輸入,作為輸入管道和訓練程式碼的介面。最簡單的例子是“one-shot iterator”,和一個dataset繫結並返回它的元素一次。通過調整初始化引數,可以實現一些複雜場景的需求,比如迴圈迭代訓練集的元素。

基本機制

基本使用的流程如下:

1. 建立一個dataset
2. 做一些資料處理(map, batch)
3. 建立一個iterator提供給訓練使用

資料結構

每個元素都有相同的資料結構,每個元素包含一個或者多個tensor,可以元組表示或者巢狀表示。每個tensor包含型別tf.DType和維度形狀tf.TensorShape。每個tensor可以有名稱,通過collections.namedtuple或者字典實現。

Dataset的資料處理介面能夠支援任意資料結構。

建立Iterator

有了dataset以後,下一步就是建立一個iterator提供訪問資料元素的介面,現在tf.data api支援4中iterator來實現不同程度的複雜場景。

  • one-shot,
  • initializable,
  • reinitializable
  • feedable.

one-shot

最簡單,最常用,就是一次遍歷所有元素,目前(2019/08)也是estimator唯一支援的iterator。

initializable

需要顯示執行初始化操作,可以通過tf.placeholder引數化dataset。

reinitializalbe

可以從多個dataset多次初始化,當然需要相同資料結構。每次切換資料集,使用前需要初始化,

feedable

可以從多個dataset多次初始化,初始化完畢後,可以通過tf.placeholder隨時切換資料集。

獲取iterator的值

通過session.run(iterator.get_next())完成,若沒有資料了則報異常:tf.errors.OutOfRangeError。

iterator.get_next()返回的是tf.Tensor物件,需要在session中執行才會有值。

儲存iterator狀態

save and restore the current state of the iterator (and, effectively, the whole input pipeline). 能將整個輸入pipeline儲存並恢復(原理是什麼?

saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

讀取輸入資料

直接讀取numpy的資料得到array,呼叫Dataset.from_tensor_slices

讀取TFRecord data: dataset = tf.data.TFRecordDataset(filenames)

通過檔案的方式,可以支援不能夠全部匯入記憶體中的場景。tf.data.TFRcordDataset將讀取資料流作為整個輸入流的一部分。

同時,檔名可以通過tf.placeholder傳入。

這裡所有的dataset的返回,也都是tensorflow中的operation,需要在session中計算得到值。

解析tf.Example

推薦情形下,輸入需要從TFRecord檔案格式讀取TF Example的protocol buffer messages,每個TF Example包含一個或者多個features,輸入管道需要將其轉換為tensor。

典型列子

# Transforms a scalar string `example_proto` into a pair of a scalar string and# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto): 
    features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),              "label": tf.FixedLenFeature((), tf.int64, default_value=0)} 
    parsed_features = tf.parse_single_example(example_proto, features)  return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

對於影象,可以通過tf.image.decode_jpeg等方式提取和resize。

任意的python邏輯處理資料情形

在文字當中,比如有時候需要呼叫其他python庫,比如jieba,這時候輸入處理需要通過t'f.py_func()opeation完成。

Batching dataset elements

簡單batch

呼叫Dataset.batch()實現,所有元素包含同樣的資料結構。

帶Padding的batch

當所有元素包含不同長度的資料結構時,可以通過padding使得單個batch資料長度一致。序列模型情形下。制定padding的shape為可變長度,比如(None,)來完成,會自動padding到單個batch的最大長度。當然,可以重寫padding的值或者padding的長度。

訓練流程

多批次遍歷

tf.data提供了兩種遍歷資料的方式

想要遍歷資料多次,可以使用Dataset.repeat(n)操作完成。遍歷n次對應多少次epochs。若不提供引數,則將無限迴圈。同時,將不會給出一個批次結束的訊號。

若想在每個批次結束後做處理,則需要在外圍加迴圈,然後通過重複初始化迭代器完成,捕捉tf.errors.OutOfRangeError異常。

隨機shuffle

dataset.shuffle()維持一個快取區,隨機取下一個資料。

高階API

tf.train.MonitoredTrainingSession 介面簡化了分散式Tensorflow的很多方面。MonitoredTrainingSession 用tf.error.OutOfRangeError來獲取訓練是否結束,所以推薦採用one-shot 迭代器。

用dataset作為estimator的輸入函式時,直接將dataset返回,estimator會自動建立迭代器並初始化。

相關文章