加速訓練之並行化 tf.data.Dataset 生成器

魚與魚發表於2022-06-11

在處理大規模資料時,資料無法全部載入記憶體,我們通常用兩個選項

  • 使用tfrecords
  • 使用 tf.data.Dataset.from_generator()

tfrecords的並行化使用前文已經有過介紹,這裡不再贅述。如果我們不想生成tfrecord中間檔案,那麼生成器就是你所需要的。

本文主要記錄針對 from_generator()的並行化方法,在 tf.data 中,並行化主要通過 mapnum_parallel_calls 實現,但是對一些場景,我們的generator()中有一些處理邏輯,是無法直接並行化的,最簡單的方法就是將generator()中的邏輯抽出來,使用map實現。

tf.data.Dataset generator 並行

generator()中的複雜邏輯,我們對其進行簡化,即僅在生成器中做一些下標取值的型別操作,將generator()中處理部分使用py_function 包裹(wrapped) ,然後呼叫map處理。

def func(i):
    i = i.numpy() # Decoding from the EagerTensor object
    x, y = your_processing_function(training_set[i])
    return x, y

z = list(range(len(training_set))) # The index generator

dataset = tf.data.Dataset.from_generator(lambda: z, tf.uint8)

dataset = dataset.map(lambda i: tf.py_function(func=func, 
                                               inp=[i], 
                                               Tout=[tf.uint8,
                                                     tf.float32]
                                               ), 
                      num_parallel_calls=tf.data.AUTOTUNE)

由於隱式推斷的原因,有時tensor的輸出shape是未知的,需要額外處理

dataset = dataset.batch(8)
def _fixup_shape(x, y):
    x.set_shape([None, None, None, nb_channels]) # n, h, w, c
    y.set_shape([None, nb_classes]) # n, nb_classes
    return x, y
dataset = dataset.map(_fixup_shape)

tf.Tensor與tf.EagerTensor

為什麼需要 tf.py_function,先來看下tf.Tensortf.EagerTensor

EagerTensor是實時的,可以在任何時候獲取到它的值,即通過numpy獲取

Tensor是非實時的,它是靜態圖中的元件,只有當喂入資料、運算完成才能獲得該Tensor的值,

map中對映的函式運算,而僅僅是告訴dataset,你每一次拿出來的樣本時要先進行一遍function運算之後才使用的,所以function的呼叫是在每次迭代dataset的時候才呼叫的,屬於靜態圖邏輯

tensorflow.python.framework.ops.EagerTensor
tensorflow.python.framework.ops.Tensor

tf.py_function在這裡起了什麼作用?

Wraps a python function into a TensorFlow op that executes it eagerly.

剛才說到map資料靜態圖邏輯,預設引數都是Tensor。而 使用tf.py_function()包裝後,引數就變成了EagerTensor。

references

【1】https://medium.com/@acordier/tf-data-dataset-generators-with-parallelization-the-easy-way-b5c5f7d2a18

【2】https://blog.csdn.net/qq_27825451/article/details/105247211

【3】https://www.tensorflow.org/guide/data_performance#parallelizing_data_extraction

相關文章