以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018

Jason Wu發表於2018-09-21

Google 開發者大會 (Google Developer Days,簡稱 GDD) 是展示 Google 最新開發者產品和平臺的全球盛會,旨在幫助你快速開發優質應用,發展和留住活躍使用者群,充分利用各種工具獲得更多收益。2018 Google 開發者大會於 9 月 20 日和 21 日於上海舉辦。?Google 開發者大會 2018 掘金專題

GDD 2018 第二天的 9 月 21 日 ,陳爽(Google Brain 軟體工程師)為我們帶來了《以 tf.data 優化訓練資料》,講解如何使用 tf.data 為各類模型打造高效能的 TensorFlow 輸入渠道,本文將摘錄演講技術乾貨。

資料輸入管道

  • 大多人將時間和金錢花在神經網路架構上,資料輸入容易被忽略
  • 沒有好的資料輸入管道,GPU 再強速度也不會顯著提高
  • 目標:高效丶靈活丶易用

ETL 系統

以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018

  1. 提取資料(Extract):將訓練資料從存取器(硬碟丶雲端等)提取
  2. 轉換資料(Transform):將資料轉換為模型可讀取的資料,同時進行資料清洗等預處理
  3. 裝載資料(Load):將處理好的資料裝載至加速器

tf.data:為機器學習設計的資料輸入系統

以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018

圖中程式碼分別對應 ETL 系統的三個步驟,使用 tf.data 即可輕鬆實現。

tf.data 優化手段:以上圖程式碼為例

  1. 多執行緒處理(使用 num_parallel_reads
files = tf.data.Dataset.list_files("training-*-of-1024.tfrecord")
dataset = tf.data.TFRecordDataset(files, num_parallel_reads=32)
複製程式碼
  1. 合併轉換步驟(如 shuffle_and_repaeat, map_and_batch
dataset = dataset.apply(tf.contrib.data.shuffle_and_repaeat(10000, NUM_EPOCHS))
dataset = dataset.apply(tf.contrib.data.map_and_batch(lambda x: ..., BATCH_SIZE))
複製程式碼
  1. 流水線化(使用 prefetch_to_device
dataset = dataset.apply(tf.contrib.data.prefetch_to_device("/gpu:0"))
複製程式碼

以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018

最終程式碼如下圖所示,更多優化手段可以參考 tf.data 效能指南

以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018

tf.data 的靈活性

支援函數語言程式設計

以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018

如上圖,可以用自定義的 map_fn 處理 TensorFlow 或相容的函式,同時支援 AutoGraph 處理過的函式。

支援不同語言與資料型別

  • 使用 Dataset.form_generator() 支援 Python 程式碼生成 Dataset
  • 使用 DatasetOpKernel 和 tf.load_op_library 支援自定義 C++ 資料處理程式碼

如下圖,使用 Python 自帶的 urllib 獲取伺服器資料,存入 dataset:

以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018

支援多種資料來源

如普通檔案系統丶GCP 雲儲存丶其他雲儲存丶SQL 資料庫等。

讀取 Google 雲儲存的 TFRecord 檔案示例:

files = tf.contrib.data.TFRecordDataset(
  "gs://path/to/file.tfrecord", num_parallel_reads=32)
複製程式碼

使用自訂 SQL 資料庫示例:

files = tf.contrib.data.SqlDataset(
  "sqllite", "/foo/db.sqlite", "SELECT name, age FROM people", 
  (tf.string, tf.int32))
複製程式碼

tf.data 的易用性

在 Eager 執行模式下,可以直接使用 Python for 迴圈:

tf.enable_eager_execution()
for batch in dataset:
    train_model(batch)
複製程式碼

為 TF Example 或 CSV 提供現有高效配方

以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018

上圖可以簡單替換為一個函式:

dataset = tf.contrib.data.make_batched_features_dataset(
  "training-*-of-1024.tfrecord",
  BATCH_SIZE, features, num_epochs=NUM_EPOCHS)
複製程式碼

使用 CSV 資料集的情境:

dataset = tf.contrib.data.make_csv_dataset(
  "*.csv", BATCH_SIZE, num_epochs=NUM_EPOCHS)
複製程式碼

使用 AUTOTUNE 自動調節管道

可以簡單的使用 AUTOTUNE 找到 prefetching 的最佳引數:

dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
複製程式碼

支援 Keras 和 Estimators 相互相容

對於 Keras,可以將 dataset 直接傳遞使用;對於 Estimators 訓練函式,將 dataset 包裝至輸入函式並返回即可,如下示例:

def input_fn():
    dataset = tf.contrib.data.make_csv_dataset(
      "*.csv", BATCH_SIZE, num_epochs=NUM_EPOCHS)
    return dataset
   
tf.estimator.Estimator(model_fn=train_model).train(input_fn=input_fn)
複製程式碼

實際運用經驗

  • 原始 tf.data 資料輸入程式碼: ~150 影象 / 秒
  • 管道化的 tf.data 資料輸入程式碼: ~1,750 影象 / 秒 => 12倍的效能!
  • Cloud TPU 上使用 tf.data: ~4,100 影象 / 秒
  • Cloud TPU Pod 上使用 tf.data: ~219,000 影象 / 秒

結論

本場演講介紹了 tf.data 這個兼具高效丶靈活與易用的 API,同時瞭解如何運用管道化及其他優化手段來增進運算效能,以及許多可能未曾發現的實用函式。

資源

以 tf.data 優化訓練資料輸入管道 丨 Google 開發者大會 2018

閱讀更多 Google 開發者大會 2018 技術乾貨

相關文章