This is a translation of a tutorial from the official PyTorch website.
Data does not always come in the final processed form required for training machine learning algorithms. We use transforms to perform some manipulation of the data and make it suitable for training.
All TorchVision datasets have two parameters: transform to modify the features and target_transform to modify the labels. Both accept callables containing the transformation logic. The torchvision.transforms module offers several commonly used transforms out of the box.
The FashionMNIST features are in PIL Image format, and the labels are integers. For training, we need the features as normalized tensors and the labels as one-hot encoded tensors. To make these transformations, we use ToTensor and Lambda.
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Output:
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|##########| 26421880/26421880 [00:01<00:00, 18198564.66it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|##########| 29515/29515 [00:00<00:00, 325487.35it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|##########| 4422102/4422102 [00:00<00:00, 6093636.48it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|##########| 5148/5148 [00:00<00:00, 39985698.13it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
ToTensor()
ToTensor converts a PIL Image or NumPy ndarray into a FloatTensor and scales the image's pixel intensity values into the range [0., 1.].
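The scaling ToTensor performs can be sketched with plain torch operations. This is a minimal illustration of the equivalent arithmetic for a single-channel uint8 image, not torchvision's actual implementation:

```python
import torch

# A tiny 2x2 "image" with uint8 pixel values, as PIL/NumPy would store it.
pixels = torch.tensor([[0, 51], [102, 255]], dtype=torch.uint8)

# For a single-channel uint8 image, ToTensor() amounts to: add a channel
# dimension (giving shape C x H x W), convert to float32, and divide by 255.
tensor = pixels.unsqueeze(0).to(torch.float32) / 255.0

print(tensor.dtype)   # torch.float32
print(tensor.shape)   # torch.Size([1, 2, 2])
print(tensor.min().item(), tensor.max().item())  # 0.0 1.0
```

Note the layout change: PIL and NumPy store images as H x W x C, while PyTorch tensors use C x H x W.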
Lambda Transforms
Lambda transforms apply any user-defined lambda function. Here, we define a function to turn the integer into a one-hot encoded tensor. It first creates a zero tensor of size 10 (the number of labels in our dataset) and calls scatter_, which assigns value=1 on the index as given by the label y.
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
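Applied to an example label such as y = 3, the scatter_ call above produces a length-10 vector with 1.0 at index 3. A quick check with plain torch, outside of any dataset:

```python
import torch

y = 3  # an example class label

# The same one-hot encoding the Lambda transform performs: a length-10
# zero vector with value=1 written at index y.
one_hot = torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)

print(one_hot)  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
```

For integer labels, torch.nn.functional.one_hot(torch.tensor(y), num_classes=10) is an alternative, though it returns an integer tensor rather than a float one.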