Reposted from: https://www.cnblogs.com/miraclepbc/p/14367560.html
Building the path list and label list
Collect all image paths
import glob
all_imgs_path = glob.glob(r"E:\datasets2\29-42\29-42\dataset2\dataset2\*.jpg")
Collect all labels
species = ['cloudy', 'rain', 'shine', 'sunrise']
all_labels = []
for img in all_imgs_path:
    for i, c in enumerate(species):
        if c in img:
            all_labels.append(i)
            break  # one label per image, even if several class names match
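The loop above looks for class names anywhere in the full path, so a folder name can interfere. For example, a folder called `training` contains the substring `rain`, and `E:\training\shine3.jpg` would then be labeled as rain. Below is a slightly more defensive variant. It assumes the class name always appears in the file name itself, as in `cloudy1.jpg`. It checks only the file name and raises an error when nothing matches, instead of silently skipping the image and misaligning the labels. `label_from_path` is a hypothetical helper name. `PureWindowsPath` is used because it accepts both `\` and `/` as separators, so Windows paths like the one above parse the same way on any OS.

```python
from pathlib import PureWindowsPath

SPECIES = ['cloudy', 'rain', 'shine', 'sunrise']


def label_from_path(path, species=SPECIES):
    """Return the index of the first class name found in the file name."""
    # PureWindowsPath accepts both backslash and slash as separators, so this
    # also works for Windows-style paths when the script runs on Linux/macOS
    name = PureWindowsPath(path).name.lower()
    for i, c in enumerate(species):
        if c in name:
            return i
    # Fail loudly: skipping the file would shift every later label by one
    raise ValueError(f"no class name found in {path!r}")


paths = [r"E:\training\shine3.jpg", "data/rain12.jpg", "cloudy1.jpg"]
print([label_from_path(p) for p in paths])  # [2, 1, 0]
```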
Defining the dataset class
from PIL import Image
from torch.utils import data

# Must implement __init__, __getitem__ and __len__
class Mydataset(data.Dataset):
    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        img = self.imgs[index]
        label = self.labels[index]
        # convert('RGB') guarantees 3 channels even for grayscale, RGBA or palette files
        pil_img = Image.open(img).convert('RGB')
        img_tensor = self.transforms(pil_img)
        return img_tensor, label

    def __len__(self):
        return len(self.imgs)
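- The basic attributes are: which images the dataset contains, what their labels are, and which transform to apply
- __getitem__ is the indexing method: given an index, it returns one (image, label) pair
- __len__ returns the number of samples in the dataset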
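The three points describe a small protocol. A map-style dataset only has to report how many samples it holds (`__len__`) and how to fetch sample `i` (`__getitem__`). Here is the same structure as a PyTorch-free illustration in plain Python. The class name `ListDataset` is hypothetical, and its string "images" and `str.upper` "transform" stand in for image paths and the PIL/transform step. The article's `Mydataset` also subclasses `torch.utils.data.Dataset`, which is the usual way to write it for PyTorch.

```python
class ListDataset:
    """Map-style dataset over in-memory samples (hypothetical, torch-free)."""

    def __init__(self, samples, labels, transform=None):
        # The three basic attributes: samples, their labels, the transform
        if len(samples) != len(labels):
            raise ValueError("every sample needs exactly one label")
        self.samples = samples
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        # Indexing method: return one (sample, label) pair
        x = self.samples[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, self.labels[index]

    def __len__(self):
        # Number of samples; a sampler uses this to produce valid indices
        return len(self.samples)


ds = ListDataset(["a.jpg", "b.jpg", "c.jpg"], [0, 2, 1], transform=str.upper)
print(len(ds))  # 3
print(ds[1])    # ('B.JPG', 2)
```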
Splitting into training and test sets
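Shuffle all the paths, and shuffle the labels with the same permutation so that each label stays with its image. The first 80% become the training set and the rest become the test set.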
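import numpy as np

index = np.random.permutation(len(all_imgs_path))  # one shared permutation
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]           # labels follow their images
s = int(len(all_imgs_path) * 0.8)                   # split point: first 80% for training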
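This step works because both arrays are indexed with the same permutation, so every path keeps its label. Below is a sketch of the same split as a reusable function. It assumes NumPy's newer `np.random.default_rng` generator and adds a `seed` argument so the split is reproducible. Neither is in the original code, and `shuffle_split` is a hypothetical name.

```python
import numpy as np


def shuffle_split(paths, labels, train_frac=0.8, seed=0):
    """Shuffle paths and labels with one shared permutation, then split."""
    paths = np.asarray(paths)
    labels = np.asarray(labels)
    if len(paths) != len(labels):
        raise ValueError("paths and labels must have the same length")
    rng = np.random.default_rng(seed)       # seeded: same split every run
    index = rng.permutation(len(paths))     # ONE permutation for both arrays
    paths, labels = paths[index], labels[index]
    s = int(len(paths) * train_frac)        # same 80% cut point as above
    return (paths[:s], labels[:s]), (paths[s:], labels[s:])


paths = [f"img{i}.jpg" for i in range(10)]
labels = list(range(10))                    # label i belongs to img{i}.jpg
(train_p, train_l), (test_p, test_l) = shuffle_split(paths, labels)
print(len(train_p), len(test_p))            # 8 2
print(all(p == f"img{l}.jpg" for p, l in zip(train_p, train_l)))  # True
```

If the paths and labels were shuffled with two separate permutations, the pairing check in the last line would fail.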
Building the training and test sets
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((96, 96)),  # resize every image to 96x96
    transforms.ToTensor()         # PIL image -> float tensor (C, H, W) in [0, 1]
])
train_ds = Mydataset(all_imgs_path[:s], all_labels[:s], transform)
test_ds = Mydataset(all_imgs_path[s:], all_labels[s:], transform)
train_dl = data.DataLoader(train_ds, batch_size=8, shuffle=True)
test_dl = data.DataLoader(test_ds, batch_size=8)
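Here is what `batch_size=8` and `shuffle=True` mean for the loaders above. The sketch is a simplified, hypothetical model of how a `DataLoader` with its default `drop_last=False` groups indices. The indices are optionally shuffled, then cut into consecutive chunks of 8, and the last chunk keeps whatever is left over. Real loaders also collate the samples in each chunk into batched tensors, which this sketch leaves out. `batch_indices` is an illustrative name, not a PyTorch function.

```python
import random


def batch_indices(n, batch_size=8, shuffle=False, seed=None):
    """Split indices 0..n-1 into batches, like DataLoader with drop_last=False."""
    order = list(range(n))
    if shuffle:
        # A real DataLoader with shuffle=True draws a fresh order every epoch
        random.Random(seed).shuffle(order)
    # Consecutive chunks of batch_size; the last one holds the remainder
    return [order[i:i + batch_size] for i in range(0, n, batch_size)]


batches = batch_indices(20, batch_size=8, shuffle=True, seed=0)
print([len(b) for b in batches])  # [8, 8, 4]
```

For example, 20 training images give three batches per epoch, and the last batch holds only 4. The test loader keeps the default `shuffle=False`, so it always visits the test images in the same order.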
Building other datasets
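You may need to apply further transformations to the dataset you just built. For example, you might want to change the layout from (channel, height, width) to (height, width, channel). In that case you can build a new dataset class that wraps the existing one: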
class New_dataset(data.Dataset):
    def __init__(self, some_ds):
        self.ds = some_ds  # the dataset being wrapped

    def __getitem__(self, index):
        img, label = self.ds[index]
        img = img.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        return img, label

    def __len__(self):
        return len(self.ds)
A quick test:
train_new_ds = New_dataset(train_ds)
img, label = train_new_ds[2]
Now img has shape (96, 96, 3).
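The layout starts out as (channel, height, width) because of `transforms.ToTensor()`. It turns a uint8 PIL image of shape (H, W, C) into a float tensor of shape (C, H, W), scaled to [0, 1]. `permute(1, 2, 0)` then reorders the axes back. The sketch below mimics both steps with NumPy on a random fake image, which is an assumption for illustration rather than real data. `ndarray.transpose` takes the same axis-order argument as `Tensor.permute`.

```python
import numpy as np

# A fake 96x96 RGB image in PIL's (height, width, channel) uint8 layout
rng = np.random.default_rng(0)
hwc_uint8 = rng.integers(0, 256, size=(96, 96, 3), dtype=np.uint8)

# What ToTensor does to such an image: scale to [0, 1] floats and move
# the channel axis to the front -> (channel, height, width)
chw = hwc_uint8.astype(np.float32).transpose(2, 0, 1) / 255.0

# New_dataset's img.permute(1, 2, 0): new axis i is old axis order[i],
# so height and width come first and channel goes last
hwc = chw.transpose(1, 2, 0)

print(chw.shape, hwc.shape)  # (3, 96, 96) (96, 96, 3)
```

The element at hwc[h, w, c] is the same value as chw[c, h, w]. Only the axis order changes, not the data.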