Dataset入門
Pytorch Dataset code:
Pytorch Dataset tutorial:
理論:
PyTorch中的Dataset
是一個抽象類,用來表示資料集的介面,所有其他資料集都需要繼承這個類,並且覆寫以下三個方法:
-
__init__:初始化資料集的一些配置,例如載入所有的資料標籤。
-
__len__:以便
len(dataset)
可以返回資料集的大小,例如n。如果n小於資料集長度,則只會取前n個的資料。 -
__getitem__:輸入是資料的索引,以便可以使用
dataset[i]
來獲取第i個樣本,資料增強一般會在這裡做。
程式碼:
下面是一個自定義的Dataset樣例(不可執行):
總結:
值得注意的是,Dataset
只負責資料的載入和預處理,對於如何訓練資料(例如:是否進行shuffle,是否進行並行加速等)這部分的邏輯是由DataLoader
實現的。通常情況下,我們會將Dataset
和DataLoader
一起使用。
另外,PyTorch還提供了一些常用的資料集,如:ImageFolder
,CIFAR10
,MNIST
等,這些資料集都是繼承Dataset
類,同時在init
方法中進行資料的下載,以及在getitem
方法中進行資料的載入和預處理。
Dataset是單執行緒讀取資料,每次只能讀取一個樣本,不能一次性讀取一個mini-batch的資料。
Dataset的主要特性包含:
-
抽象介面:PyTorch透過定義一個抽象
Dataset
類,讓使用者可以使用統一的方式來載入各種不同的資料,提供了很好的擴充套件性。 -
懶載入:實際的資料載入並不發生在構造資料集例項時,而是發生在用到這些資料時,這樣可以提高記憶體利用率,並且可以實現對大規模資料的處理。
-
預處理:
Dataset
的一個重要應用就是資料預處理,你可以在getitem
函式中進行任何你的資料預處理過程。
嗨,歡迎大家關注我的公眾號《CV之路》,一起討論問題,一起學習進步~