Pytorch Dataset入門

gy77發表於2024-04-16

Dataset入門

Pytorch Dataset code:torch/utils/data/dataset.py#L17

Pytorch Dataset tutorial: tutorials/beginner/basics/data_tutorial.html

理論:

PyTorch中的Dataset是一個抽象類,用來表示資料集的介面,所有其他資料集都需要繼承這個類,並且覆寫以下三個方法:

  1. __init__:初始化資料集的一些配置,例如載入所有的資料標籤。

  2. __len__:以便len(dataset)可以返回資料集的大小,例如n。如果n小於資料集長度,則只會取前n個的資料。

  3. __getitem__:輸入是資料的索引,以便可以使用dataset[i]來獲取第i個樣本,資料增強一般會在這裡做。

程式碼:

下面是一個自定義的Dataset樣例(不可執行):

import cv2
import json
import torch.utils.Dataset as Dataset

class CustomDataset(Dataset):
    def __init__(self, imgs_path, labels_path, img_transform=None, label_transform=None):
        self.imgs_path = imgs_path  # 輸入影像的路徑,list
        self.labels_path = labels_path  # 輸入影像對應的標籤路徑,list
        self.img_transform = img_transform  # 影像的資料增強
        self.label_transform = label_transform  # 標籤的資料增強

    def __len__(self):
        return len(self.imgs_path)  # 返回資料集的長度

    def __getitem__(self, idx):
        img_path = self.imgs_path[idx]
        label_path = self.labels_path[idx]
        img = cv2.imread(img_path)  # 讀取影像
        label = json.load(open(label_path))  # 讀取標籤
        if self.img_transform:  # 影像的資料增強
            img = self.img_transform(img)
        if self.label_transform:  # 標籤的資料增強
            label = self.label_transform(label)
        return img, label  # 返回影像和標籤,用於訓練
Pytorch Dataset入門

總結:

值得注意的是,Dataset只負責資料的載入和預處理,對於如何訓練資料(例如:是否進行shuffle,是否進行並行加速等)這部分的邏輯是由DataLoader實現的。通常情況下,我們會將DatasetDataLoader一起使用。

另外,PyTorch還提供了一些常用的資料集,如:ImageFolderCIFAR10MNIST等,這些資料集都是繼承Dataset類,同時在init方法中進行資料的下載,以及在getitem方法中進行資料的載入和預處理。

Dataset是單執行緒讀取資料,每次只能讀取一個樣本,不能一次性讀取一個mini-batch的資料。

Dataset的主要特性包含:

  • 抽象介面:PyTorch透過定義一個抽象Dataset類,讓使用者可以使用統一的方式來載入各種不同的資料,提供了很好的擴充套件性。

  • 懶載入:實際的資料載入並不發生在構造資料集例項時,而是發生在用到這些資料時,這樣可以提高記憶體利用率,並且可以實現對大規模資料的處理。

  • 預處理:Dataset的一個重要應用就是資料預處理,你可以在getitem函式中進行任何你的資料預處理過程。


嗨,歡迎大家關注我的公眾號《CV之路》,一起討論問題,一起學習進步~

相關文章