tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解

輕墨發表於2018-08-28

tensorflow中的資料集類Dataset有一個shuffle方法,用來打亂資料集中資料順序,訓練時非常常用。其中shuffle方法有一個引數buffer_size,非常令人費解,文件的解釋如下:

buffer_size: A tf.int64 scalar tf.Tensor, representing the number of elements from this dataset from which the new dataset will sample.

你看懂了嗎?反正我反覆看了這說明十幾次,仍然不知所指。

首先,Dataset會取所有資料的前buffer_size資料項,填充 buffer,如下圖

tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解

然後,從buffer中隨機選擇一條資料輸出,比如這裡隨機選中了item 7,那麼bufferitem 7對應的位置就空出來了

tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解

然後,從Dataset中順序選擇最新的一條資料填充到buffer中,這裡是item 10

tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解

然後在從Buffer中隨機選擇下一條資料輸出。

需要說明的是,這裡的資料項item,並不只是單單一條真實資料,如果有batch size,則一條資料項item包含了batch size條真實資料。

shuffle是防止資料過擬合的重要手段,然而不當的buffer size,會導致shuffle無意義,具體可以參考這篇Importance of buffer_size in shuffle()

tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解

相關文章