image-classification-dataset

Mr小明同学發表於2024-06-23
import torchtext
from torchvision import transforms
from torch.utils import data
from d2l import torch as d2l
import torchvision

# Convert each image to a tensor before it reaches the model.
trans = transforms.ToTensor()

# Training split.
fashion_mnist_train = torchvision.datasets.FashionMNIST("../data",
                                                        train=True,
                                                        transform=trans,
                                                        download=True)
# Test split: train=False selects the held-out test images. Passing
# train=True here would just load the training split a second time.
fashion_mnist_test = torchvision.datasets.FashionMNIST("../data",
                                                       train=False,
                                                       transform=trans,
                                                       download=True)

def get_fashion_mnist_label(label):
    """Map Fashion-MNIST class indices to their text names.

    Args:
        label: An iterable of integer class indices (e.g. a list of ints or a
            1-D integer tensor such as a batch of targets ``Y``).

    Returns:
        A list of class-name strings, one per input index, in input order.

    Raises:
        IndexError: If an index is outside ``[0, 9]``. Negative indices are
            rejected explicitly because plain list indexing would silently
            wrap them around to the wrong class name.
        TypeError: If an element is not an integer-like value.
    """
    import operator

    text_labels = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
    names = []
    for item in label:
        # operator.index accepts exactly what list indexing accepts
        # (ints, integer 0-d tensors, numpy ints) and rejects floats.
        idx = operator.index(item)
        if not 0 <= idx < len(text_labels):
            raise IndexError(
                f"Fashion-MNIST label out of range [0, {len(text_labels) - 1}]: {idx}"
            )
        names.append(text_labels[idx])
    return names

batch_size = 50

# Build both loaders in one pass: the training loader reshuffles every
# epoch, the test loader keeps a fixed order. Both load in the main process.
dataloader_train, dataloader_test = (
    data.DataLoader(split, batch_size=batch_size, shuffle=do_shuffle, num_workers=0)
    for split, do_shuffle in ((fashion_mnist_train, True),
                              (fashion_mnist_test, False))
)
# Quick check of list.insert(0, x): it prepends x. This is the same call
# get_fashion_mnist_dataloader uses to put Resize in front of ToTensor.
a = [1, 2]
a.insert(0, 5)
print(a)  # prints: [5, 1, 2]
def get_fashion_mnist_dataloader(batch_size, resize=None, *, root="../data", num_workers=0):
    """Build train/test DataLoaders for Fashion-MNIST.

    Args:
        batch_size: Number of samples per batch for both loaders.
        resize: Optional size passed to ``transforms.Resize``. When falsy,
            images are only converted to tensors.
        root: Directory the dataset is read from / downloaded into.
            Defaults to ``"../data"``.
        num_workers: Worker count forwarded to both ``DataLoader``s.
            Defaults to 0.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader
        shuffles; the test loader iterates in a fixed order.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Prepend Resize so it runs before ToTensor in the Compose pipeline.
        steps.insert(0, transforms.Resize(resize))
    pipeline = transforms.Compose(steps)

    def _make_loader(is_train, shuffle):
        # Load one split (train or test) and wrap it in a DataLoader.
        dataset = torchvision.datasets.FashionMNIST(root,
                                                    train=is_train,
                                                    transform=pipeline,
                                                    download=True)
        return data.DataLoader(dataset,
                               batch_size=batch_size,
                               shuffle=shuffle,
                               num_workers=num_workers)

    return _make_loader(True, True), _make_loader(False, False)
batch_size = 50
train_dataloader, test_dataloader = get_fashion_mnist_dataloader(batch_size, resize=(60, 60))

# Peek at the first batch of each loader to check the tensor shapes.
for X, Y in train_dataloader:
    print(X.shape, Y.shape)
    break
for X, Y in test_dataloader:
    print(X.shape, Y.shape)
    break
# Output recorded in the original notebook run:
# torch.Size([50, 1, 60, 60]) torch.Size([50])
# torch.Size([50, 1, 60, 60]) torch.Size([50])

重點函式

  • transforms.ToTensor() 不可忘記括號
  • 可結合 transforms.Compose():傳入一個轉換列表,依列表順序對資料逐一處理,得到合成後的 trans
  • 圖片轉換相關函式 在 torchvision.transforms
  • 常用資料集(dataset)在 torchvision.datasets 裡
  • 讀取常用資料集流程
    • 應用 torchvision.transforms 封裝好trans
    • 應用 torchvision.datasets 讀取出相應資料集
      • 配置引數:
      • root:資料存放位置(例如 "../data"),
      • train:bool,True 讀取訓練集,False 則讀取測試集,
      • transform:格式轉換。
      • download:是否下載資料集
    • 應用 torch.utils.data.DataLoader 將 dataset 封裝為一個可迭代物件
      • 配置引數:
      • dataset:對應dataset
      • batch_size:批次大小
      • shuffle:是否打亂