import torchtext
from torchvision import transforms
from torch.utils import data
from d2l import torch as d2l
import torchvision
# Transform applied to every sample when it is read from the dataset.
trans = transforms.ToTensor()

# Training split of Fashion-MNIST, stored under ../data.
fashion_mnist_train = torchvision.datasets.FashionMNIST(
    "../data",
    train=True,
    transform=trans,
    download=True,
)

# Test split: train=False selects the held-out test set. This used to pass
# train=True, which made the "test" dataset a second copy of the training data.
fashion_mnist_test = torchvision.datasets.FashionMNIST(
    "../data",
    train=False,
    transform=trans,
    download=True,
)
def get_fashion_mnist_label(label):
    """Map Fashion-MNIST class indices to their text names.

    Args:
        label: Iterable of class indices. Each element is converted with
            ``int()`` before the lookup, so integer-valued elements that
            are not plain ``int`` objects (for example ``2.0``) are
            accepted as well; non-integer values are truncated by ``int()``.

    Returns:
        A list with one label string per element of ``label``, in the
        same order.

    Raises:
        IndexError: If a converted index is outside the range of the
            ten-entry label list.
    """
    text_labels = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat',
        'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
    ]
    # int() makes the list lookup independent of the element type.
    return [text_labels[int(i)] for i in label]
# Number of samples per mini-batch, shared by both loaders below.
batch_size = 50

# Training loader: batches are drawn in shuffled order.
dataloader_train = data.DataLoader(
    fashion_mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Test loader: same batch size, samples kept in dataset order.
dataloader_test = data.DataLoader(
    fashion_mnist_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)
# Scratch check of list.insert: inserting at index 0 puts the new element
# first (the same call is used to put Resize ahead of ToTensor further down).
a = [1, 2]
a.insert(0, 5)
print(a)
[5, 1, 2]
def get_fashion_mnist_dataloader(batch_size, resize=None, root="../data", num_workers=0):
    """Build the training and test DataLoaders for Fashion-MNIST.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        resize: Optional size handed to ``transforms.Resize``. When it is
            falsy (the default ``None``), images are not resized.
        root: Directory passed to ``FashionMNIST`` as the dataset
            location. Defaults to ``"../data"``.
        num_workers: ``num_workers`` value for both loaders. Defaults to 0.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader
        shuffles its samples; the test loader does not.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Placed at the front so Resize runs before ToTensor in the pipeline.
        steps.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(steps)

    train_dataset = torchvision.datasets.FashionMNIST(
        root,
        train=True,
        transform=trans,
        download=True,
    )
    test_dataset = torchvision.datasets.FashionMNIST(
        root,
        train=False,
        transform=trans,
        download=True,
    )

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader
# Smoke check: build both loaders with 60x60 resizing and print the shapes
# of the first batch from each one.
batch_size = 50
train_dataloader, test_dataloader = get_fashion_mnist_dataloader(
    batch_size,
    resize=(60, 60),
)

# Only the first training batch is inspected, hence the immediate break.
for X, Y in train_dataloader:
    print(X.shape, Y.shape)
    break

# Same check for the first test batch.
for X, Y in test_dataloader:
    print(X.shape, Y.shape)
    break
torch.Size([50, 1, 60, 60]) torch.Size([50])
torch.Size([50, 1, 60, 60]) torch.Size([50])
重點函式
- transforms.ToTensor() 不可忘記括號
- 可結合 transforms.Compose():傳入一個 transform 列表,依列表順序對資料進行處理,作為合成後的 trans
- 圖片轉換相關函式 在 torchvision.transforms
- 常用 dataset 在 torchvision.datasets 裡
- 讀取常用資料集流程
- 應用 torchvision.transforms 封裝好trans
- 應用 torchvision.datasets 讀取出相應資料集
- 配置引數:
- root:資料位置(第一個引數),
- train:bool,是否讀取訓練集;False 則代表測試集,
- transform:格式轉換。
- download:是否下載資料集
- 應用 torch.utils.data.DataLoader 將 dataset 封裝為一個可迭代物件
- 配置引數:
- dataset:對應dataset
- batch_size:批次大小
- shuffle:是否打亂