用於處理資料樣本的程式碼可能會變得凌亂且難以維護;理想情況下,我們希望資料集程式碼和模型訓練程式碼解耦(分離),以獲得更好的可讀性和模組性。PyTorch提供了兩個data primitives:torch.utils.data.DataLoader
和 torch.utils.data.Dataset
,允許你使用預載入的datasets和你自己的data。Dataset
儲存樣本及其對應的標籤,DataLoader
給 Dataset
包裝了一個迭代器,以便訪問樣本。
PyTorch庫提供了一些預載入的資料集(如FashionMNIST),它們是 torch.utils.data.Dataset
的子類,特定的資料對應特定的實現函式。它們可以用來原型化和基準化你的模型。你可以在這裡檢視它們:Image Datasets, Text Datasets, and Audio Datasets。
載入資料集
這是一個怎樣從TorchVision載入Fashion-MNIST資料集的例子。Fashion-MNIST來自於Zalando的文章,由60000張訓練樣本和10000張測試樣本組成。每一個樣本包含一個28x28
的灰度圖片和對應的10類中的1個類的標籤。
我們用以下引數載入FashionMNIST Dataset
root
是訓練/測試資料的儲存路徑train
指定是訓練集還是測試集download=True
如果root
中沒有,則從網上下載transform
和target_transform
指定樣本的變換
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root='data',
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root='data',
train=False,
download=True
transform=ToTensor()
)
輸出:
點選檢視程式碼
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
迭代和資料集視覺化
我們可以像list一樣索引Datasets:training_data[index]
。使用 matplotlib
視覺化一些訓練集的樣本。
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
# torch.squeeze():刪除維數為1的維度
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
建立自定義資料集
一個自定義的資料集類必須實現三個函式:init,len,getitem。檢視下面的實現過程,FashionMNIST圖片儲存在 img_dir
,它們的標籤分別儲存在一個CSV檔案(逗號分隔值檔案) annotations_file
中。
下一節,我們將分解每個函式做了什麼的。
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
# 利用pandas讀取csv並轉換為DataFrame
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
init
一旦例項化Datase物件,函式__init__
就會立即執行:初始化包含圖片的目錄,標籤檔案,以及兩個轉換(下一節有更詳細的介紹)
labels.csv類似這樣:
tshirt1.jpg, 0
tshirt2.jpg, 0
...
anleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
# 這裡指定了列名
self.img_labels = pd.read_csv(annotations_file, names=['file_name', 'labels'])
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
len
__len__
函式返回資料集的樣本數
例如:
def __len__(self):
return len(self.img_labels)
getitem
__getitem__
函式載入和返回資料集中給定索引 idx
的樣本。根據索引,它獲得了硬碟上圖片的位置,利用 read_image
轉換為tensor,在 self.img_labels
,從csv中檢索相應的標籤,並呼叫轉換函式(如果可用),返回一個包含圖片和對應標籤張量的元組。
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
利用DataLoader為訓練準備你的資料
Dataset
只能同時檢索一個樣本的資料特徵和標籤。當訓練模型時,通常需要傳遞“minibatches”樣本,每一個epoch重複打亂資料減少過擬合,並使用Python的 multiprocessing
加速資料檢索。
DataLoader
是一個迭代器。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
通過DataLoader迭代
我們已經將該資料集載入到 DataLoader
,根據需要可以對資料集進行迭代。每次迭代返回一個 train_features
和 train_labels
的batch(分別包含 batch_size=64
的特徵和標籤)。因為我們指定了 shuffle=True
, 在我們迭代完所有的batch之後,資料就會被打亂(為了對資料載入順序進行更細緻的控制,參閱Samplers)
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
輸出:
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 7