PyTorch 介紹 | DATSETS & DATALOADERS

Deep_RS發表於2022-01-28

用於處理資料樣本的程式碼可能會變得凌亂且難以維護;理想情況下,我們希望資料集程式碼和模型訓練程式碼解耦(分離),以獲得更好的可讀性和模組性。PyTorch提供了兩個data primitives:torch.utils.data.DataLoadertorch.utils.data.Dataset,允許你使用預載入的datasets和你自己的data。Dataset 儲存樣本及其對應的標籤,DataLoaderDataset 包裝了一個迭代器,以便訪問樣本。

PyTorch庫提供了一些預載入的資料集(如FashionMNIST),它們是 torch.utils.data.Dataset 的子類,特定的資料對應特定的實現函式。它們可以用來原型化和基準化你的模型。你可以在這裡檢視它們:Image Datasets, Text Datasets, and Audio Datasets

載入資料集

這是一個怎樣從TorchVision載入Fashion-MNIST資料集的例子。Fashion-MNIST來自於Zalando的文章,由60000張訓練樣本和10000張測試樣本組成。每一個樣本包含一個28x28
的灰度圖片和對應的10類中的1個類的標籤。

我們用以下引數載入FashionMNIST Dataset

  • root 是訓練/測試資料的儲存路徑
  • train 指定是訓練集還是測試集
  • download=True 如果 root 中沒有,則從網上下載
  • transformtarget_transform 指定樣本的變換
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True
    transform=ToTensor()
)

輸出:

點選檢視程式碼
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

迭代和資料集視覺化

我們可以像list一樣索引Datasets:training_data[index]。使用 matplotlib 視覺化一些訓練集的樣本。

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    # torch.squeeze():刪除維數為1的維度
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

建立自定義資料集

一個自定義的資料集類必須實現三個函式:init,len,getitem。檢視下面的實現過程,FashionMNIST圖片儲存在 img_dir,它們的標籤分別儲存在一個CSV檔案(逗號分隔值檔案) annotations_file 中。

下一節,我們將分解每個函式做了什麼的。

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # 利用pandas讀取csv並轉換為DataFrame
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
    
    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

init

一旦例項化Datase物件,函式__init__ 就會立即執行:初始化包含圖片的目錄,標籤檔案,以及兩個轉換(下一節有更詳細的介紹)

labels.csv類似這樣:

tshirt1.jpg, 0
tshirt2.jpg, 0
...
anleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    # 這裡指定了列名
    self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'labels'])
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

len

__len__ 函式返回資料集的樣本數

例如:

def __len__(self):
    return len(self.img_labels)

getitem

__getitem__函式載入和返回資料集中給定索引 idx 的樣本。根據索引,它獲得了硬碟上圖片的位置,利用 read_image 轉換為tensor,在 self.img_labels ,從csv中檢索相應的標籤,並呼叫轉換函式(如果可用),返回一個包含圖片和對應標籤張量的元組。

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

利用DataLoader為訓練準備你的資料

Dataset只能同時檢索一個樣本的資料特徵和標籤。當訓練模型時,通常需要傳遞“minibatches”樣本,每一個epoch重複打亂資料減少過擬合,並使用Python的 multiprocessing 加速資料檢索。

DataLoader 是一個迭代器。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

通過DataLoader迭代

我們已經將該資料集載入到 DataLoader,根據需要可以對資料集進行迭代。每次迭代返回一個 train_featurestrain_labels 的batch(分別包含 batch_size=64的特徵和標籤)。因為我們指定了 shuffle=True, 在我們迭代完所有的batch之後,資料就會被打亂(為了對資料載入順序進行更細緻的控制,參閱Samplers

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

輸出:

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 7

延伸閱讀