Pytorch資料載入與使用

云岛夜川川發表於2024-06-17

前言

在訓練的時候通常使用Dataset來處理資料集。

Dataset的作用

提供一個方式獲取資料內容和標籤(label)。

實戰

from torch.utils.data import Dataset

from PIL import Image
import os

class get_data(Dataset):

    def __init__(self,root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.img_dir = os.path.join(root_dir,label_dir)
        self.img_list = os.listdir(self.img_dir)


    def __getitem__(self, indx):

        img_path = os.path.join(self.img_dir,self.img_list[indx])
        img_label = self.label_dir
        img_data = Image.open(img_path)
        return img_data,img_label

    def __len__(self):
        return len(self.img_list)

root_dir = "C:\\Users\\Traveler\\Pictures"
label_dir = "Screenshots"

test = get_data(root_dir,label_dir)

img , label = test[1]

# img.show()
print(label)
print(len(test))

相關文章