Pytorch | Tutorial-03 Transforms

Published by 一碗给力嗯 on 2024-03-20

This is a Chinese translation of the tutorial from the official PyTorch website.

Data does not always come in the final processed form required for training machine learning algorithms. We use transforms to perform some manipulation of the data and make it suitable for training.

All TorchVision datasets have two parameters: transform to modify the features and target_transform to modify the labels. Both accept callables containing the transformation logic. The torchvision.transforms module offers several commonly used transforms out of the box.

The FashionMNIST features are in PIL Image format, and the labels are integers. For training, we need the features as normalized tensors and the labels as one-hot encoded tensors. To make these transformations, we use ToTensor and Lambda.

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

Output:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz

  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 65536/26421880 [00:00<01:12, 362470.31it/s]
  1%|          | 229376/26421880 [00:00<00:38, 681259.72it/s]
  4%|3         | 950272/26421880 [00:00<00:11, 2185553.59it/s]
 15%|#4        | 3833856/26421880 [00:00<00:02, 7599317.20it/s]
 34%|###4      | 9109504/26421880 [00:00<00:00, 18310296.11it/s]
 46%|####5     | 12091392/26421880 [00:00<00:00, 17936658.84it/s]
 68%|######7   | 17924096/26421880 [00:01<00:00, 22974578.28it/s]
 89%|########9 | 23592960/26421880 [00:01<00:00, 25758355.11it/s]
100%|##########| 26421880/26421880 [00:01<00:00, 18198564.66it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz

  0%|          | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 325487.35it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz

  0%|          | 0/4422102 [00:00<?, ?it/s]
  1%|1         | 65536/4422102 [00:00<00:12, 362947.95it/s]
  5%|5         | 229376/4422102 [00:00<00:06, 682324.89it/s]
 21%|##1       | 950272/4422102 [00:00<00:01, 2189897.25it/s]
 87%|########6 | 3833856/4422102 [00:00<00:00, 7611069.08it/s]
100%|##########| 4422102/4422102 [00:00<00:00, 6093636.48it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz

  0%|          | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 39985698.13it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

ToTensor()

ToTensor converts a PIL Image or NumPy ndarray into a FloatTensor and scales the image's pixel intensity values into the range [0., 1.].

Lambda Transforms

Lambda transforms apply any user-defined lambda function. Here, we define a function to turn the integer label into a one-hot encoded tensor. It first creates a zero tensor of size 10 (the number of labels in our dataset) and then calls scatter_, which assigns value=1 at the index given by the label y.

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
