When running experiments, we often test against open-source datasets. PyTorch ships with many of these built in, and we typically load them with the DataLoader class.
For example, below we use the DataLoader class to load the FashionMNIST dataset from torchvision.
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
Next we define DataLoader objects to load these two datasets:
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
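As a quick sanity check (a small sketch reusing the objects defined above), the length of a DataLoader is the number of batches it yields per epoch, i.e. the dataset size divided by batch_size, rounded up:

print(len(training_data))     # 60000 samples in the training split
print(len(train_dataloader))  # ceil(60000 / 64) = 938 batches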
So what exactly is the type of this train_dataloader?
print(type(train_dataloader)) # <class 'torch.utils.data.dataloader.DataLoader'>
We can first convert it to an iterator:
print(type(iter(train_dataloader)))  # <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>
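As an aside, which concrete iterator class you get depends on the num_workers argument: with worker processes enabled, DataLoader hands back a multi-process iterator instead. A small sketch (note that the underscore-prefixed class names are private implementation details of torch.utils.data and may change between versions):

mp_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=2)
print(type(iter(mp_dataloader)))
# <class 'torch.utils.data.dataloader._MultiProcessingDataLoaderIter'>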
We can then call next() on that iterator to pull a batch out of it, as shown below:
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()  # drop the channel dim: (1, 28, 28) -> (28, 28)
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
We have successfully retrieved the first image of the batch, and the console prints:
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 2
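The label printed here is an integer class index. torchvision datasets expose the human-readable names through a classes attribute, so we can map the index back to a name (a small sketch reusing the objects above):

# FashionMNIST's ten classes are clothing categories; index 2 is "Pullover".
print(training_data.classes[int(label)])  # -> Pullover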
The image is visualized as follows:
At this point some readers may wonder: much of the time we never explicitly cast the DataLoader to an iterator; more often we write code like this:
for train_features, train_labels in train_dataloader:
    print(train_features.shape)               # torch.Size([64, 1, 28, 28])
    print(train_features[0].shape)            # torch.Size([1, 28, 28])
    print(train_features[0].squeeze().shape)  # torch.Size([28, 28])
    img = train_features[0].squeeze()
    label = train_labels[0]
    plt.imshow(img, cmap="gray")
    plt.show()
    print(f"Label: {label}")
As you can see, this code iterates over the training data just as well; the console output for the first three batches is:
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 7
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 4
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 1
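Note that, as written, the loop above walks through every batch in the epoch (938 of them here). To reproduce just the three batches shown, one common pattern, sketched below, is to break out early with enumerate:

for i, (train_features, train_labels) in enumerate(train_dataloader):
    print(train_features.shape)  # torch.Size([64, 1, 28, 28])
    if i == 2:
        break  # stop after the first three batches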
So why did we not need to convert the DataLoader into an iterator explicitly here? The answer lies in how Python's for loop works: as soon as we iterate over an object with the for ... in ... construct, the interpreter quietly creates the iterator for us behind the scenes. In other words,
for train_features, train_labels in train_dataloader:
is effectively equivalent to
for train_features, train_labels in iter(train_dataloader):
and, going one step further, is equivalent to:
train_iterator = iter(train_dataloader)
try:
    while True:
        train_features, train_labels = next(train_iterator)
        # ... the loop body runs here ...
except StopIteration:
    pass
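To see the mechanism in isolation, here is a minimal sketch of the iterator protocol that for ... in ... relies on: any object whose __iter__ returns an object implementing __next__ can be looped over, which is exactly what DataLoader does when it hands back its _SingleProcessDataLoaderIter. The CountDown classes below are purely illustrative:

class CountDown:
    """A toy iterable, analogous in spirit to DataLoader."""
    def __init__(self, start):
        self.start = start

    def __iter__(self):
        # Called implicitly by iter() and by the for statement.
        return CountDownIterator(self.start)

class CountDownIterator:
    """The companion iterator, analogous to _SingleProcessDataLoaderIter."""
    def __init__(self, current):
        self.current = current

    def __next__(self):
        # Called implicitly by next() on every loop step.
        if self.current <= 0:
            raise StopIteration  # tells the for loop to terminate
        self.current -= 1
        return self.current + 1

for n in CountDown(3):
    print(n)  # prints 3, 2, 1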
More generally, when we iterate directly over a list in Python:
for x in [1, 2, 3, 4]:
    print(x)
the interpreter has already converted it to an iterator for us implicitly:
list_iterator = iter([1, 2, 3, 4])
try:
    while True:
        x = next(list_iterator)
        print(x)
except StopIteration:
    pass
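One practical consequence of this desugaring, sketched below, is that iterators are single-use: once exhausted, they keep raising StopIteration, whereas the underlying iterable can be iterated again and again:

it = iter([1, 2, 3])
print(list(it))  # [1, 2, 3]
print(list(it))  # [] -- the iterator is exhausted

This is why the same train_dataloader can be looped over once per epoch: each for statement begins by calling iter() on it afresh.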