模組作用
tf.data api用於建立訓練前匯入資料和資料處理的pipeline,使得處理大規模資料,不同資料格式和複雜資料處理變的容易。
基本抽象
提供了兩種基本抽象:Dataset和Iterator
Dataset
表示元素序列集合,每個元素包含一個或者多個Tensor物件,每個元素是一個樣本。有兩種方式可以建立Dataset。
- 從源資料建立,比如:Dataset.from_tensor_slices()
- 通過資料處理轉換建立,比如 Dataset.map()/batch()
Iterator
用於從Dataset獲取資料給訓練作為輸入,作為輸入管道和訓練程式碼的介面。最簡單的例子是“one-shot iterator”,和一個dataset繫結並返回它的元素一次。通過調整初始化引數,可以實現一些複雜場景的需求,比如迴圈迭代訓練集的元素。
基本機制
基本使用的流程如下:
1. 建立一個dataset
2. 做一些資料處理(map, batch)
3. 建立一個iterator提供給訓練使用
資料結構
每個元素都有相同的資料結構,每個元素包含一個或者多個tensor,可以元組表示或者巢狀表示。每個tensor包含型別tf.DType和維度形狀tf.TensorShape。每個tensor可以有名稱,通過collections.namedtuple或者字典實現。
Dataset的資料處理介面能夠支援任意資料結構。
建立Iterator
有了dataset以後,下一步就是建立一個iterator提供訪問資料元素的介面,現在tf.data api支援4中iterator來實現不同程度的複雜場景。
- one-shot,
- initializable,
- reinitializable
- feedable.
one-shot
最簡單,最常用,就是一次遍歷所有元素,目前(2019/08)也是estimator唯一支援的iterator。
initializable
需要顯示執行初始化操作,可以通過tf.placeholder引數化dataset。
reinitializalbe
可以從多個dataset多次初始化,當然需要相同資料結構。每次切換資料集,使用前需要初始化,
feedable
可以從多個dataset多次初始化,初始化完畢後,可以通過tf.placeholder隨時切換資料集。
獲取iterator的值
通過session.run(iterator.get_next())完成,若沒有資料了則報異常:tf.errors.OutOfRangeError。
iterator.get_next()返回的是tf.Tensor物件,需要在session中執行才會有值。
儲存iterator狀態
save and restore the current state of the iterator (and, effectively, the whole input pipeline). 能將整個輸入pipeline儲存並恢復(原理是什麼?)
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
讀取輸入資料
直接讀取numpy的資料得到array,呼叫Dataset.from_tensor_slices
讀取TFRecord data: dataset = tf.data.TFRecordDataset(filenames)
通過檔案的方式,可以支援不能夠全部匯入記憶體中的場景。tf.data.TFRcordDataset將讀取資料流作為整個輸入流的一部分。
同時,檔名可以通過tf.placeholder傳入。
這裡所有的dataset的返回,也都是tensorflow中的operation,需要在session中計算得到值。
解析tf.Example
推薦情形下,輸入需要從TFRecord檔案格式讀取TF Example的protocol buffer messages,每個TF Example包含一個或者多個features,輸入管道需要將其轉換為tensor。
典型列子
# Transforms a scalar string `example_proto` into a pair of a scalar string and# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""), "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features) return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
對於影象,可以通過tf.image.decode_jpeg等方式提取和resize。
任意的python邏輯處理資料情形
在文字當中,比如有時候需要呼叫其他python庫,比如jieba,這時候輸入處理需要通過t'f.py_func()opeation完成。
Batching dataset elements
簡單batch
呼叫Dataset.batch()實現,所有元素包含同樣的資料結構。
帶Padding的batch
當所有元素包含不同長度的資料結構時,可以通過padding使得單個batch資料長度一致。序列模型情形下。制定padding的shape為可變長度,比如(None,)來完成,會自動padding到單個batch的最大長度。當然,可以重寫padding的值或者padding的長度。
訓練流程
多批次遍歷
tf.data提供了兩種遍歷資料的方式
想要遍歷資料多次,可以使用Dataset.repeat(n)操作完成。遍歷n次對應多少次epochs。若不提供引數,則將無限迴圈。同時,將不會給出一個批次結束的訊號。
若想在每個批次結束後做處理,則需要在外圍加迴圈,然後通過重複初始化迭代器完成,捕捉tf.errors.OutOfRangeError異常。
隨機shuffle
dataset.shuffle()維持一個快取區,隨機取下一個資料。
高階API
tf.train.MonitoredTrainingSession 介面簡化了分散式Tensorflow的很多方面。MonitoredTrainingSession 用tf.error.OutOfRangeError來獲取訓練是否結束,所以推薦採用one-shot 迭代器。
用dataset作為estimator的輸入函式時,直接將dataset返回,estimator會自動建立迭代器並初始化。