深度學習入門筆記——DataLoader的使用

cyMessi發表於2024-10-29

如何使用資料集DataSet?

在介紹DataLoader之前,需要先了解資料集DataSet的使用。Pytorch中整合了很多已經處理好的資料集,在pytorch的torchvision、torchtext等模組有一些典型的資料集,可以透過配置來下載使用。

以CIFAR10 資料集為例,文件已經描述的很清晰了,其中要注意的就是transform這個引數,可以用來將影像轉換為所需要的格式,就比如這樣,將PIL格式的影像轉化為tensor格式的影像:

# 準備的測試資料集
test_data=torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)

DataLoader是什麼?

我們可以這樣理解:如果Dataset資料集是一個儲存所有資料(影像、音訊)的容器,那麼DataLoader就是另一個具有更好收納功能的容器,其中分隔開來很多小隔間,可以自己設定一個小隔間有多少個資料集的資料來組成,每次將資料放進收納小隔間的時候要不要把源資料集打亂再進行收納等等
也就是說,給定了一個資料集,我們可以決定如何從資料集裡面拿取資料來進行訓練,比如一次拿取多少資料作為一個物件來對資料集進行分割,對資料集進行分割之前要不要打亂資料集等等。DataLoader的結果就是一個對資料集進行分割的大字典列表,列表中的每個物件都是由設定的多少個資料集的物件組合而成的

如何使用DataLoader?

__getitem__方法

首先需要先理解__getitem__方法,__getitem__被稱為魔法方法,在python中定義一個類的時候,如果想要透過鍵來得到類的輸出值,就需要__getitem__方法。所以__getitem__方法的作用就是在呼叫類的時候自動的執行__getitem__方法的內容,得結果並返回

class Fib():                  #定義類Fib
    def __init__(self,start=0,step=1):
        self.step=step
    def __getitem__(self, key): #定性__getitem__函式,key為類Fib的鍵
            a = key+self.step
            return a          #當按照鍵取值時,返回的值為a
 
s=Fib()
s[1]  #返回2 ,因為類有 __getitem__方法,所以可以直接透過鍵來取對應的值

如果沒有__getitem__方法,那麼就無法透過鍵來得到返回值

class Fib():                  #定義類Fib
    def __init__(self,start=0,step=1):
        self.step=step
s=Fib()
s[1] 
返回:TypeError: 'Fib' object does not support indexing

以Pytorch中的CIFAR10資料集為例,可以看到原始碼中的__getitem__方法是這樣的:

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

可以理解為在呼叫類的時候如果輸入index,也就是這個類中的索引/鍵,那麼就可以自動呼叫__getitem__方法得到返回值image和target,其中image就是資料集中的影像,target是標籤類class中的索引,用來指示label是什麼

DataLoader語法

可以在Pytorch的Documents文件中檢視DataLoader的使用方法,一部分截圖如下所示
這裡介紹幾個比較常用的:

  • dataset:就是我們的資料集,我們構建好資料集物件之後傳入即可

  • batch_size:也就是在資料集容器中一次拿取多少資料,然後這一次拿取的資料就作為一個Dataloader中的一個物件。如果還是不太理解,可以看下面的示例:
    第一張圖片是我設定batch_size=4的時候給出的DataLoader中的一個物件,第二張圖片是batch_size=64的時候給出的DataLoader中的一個物件,這樣就好理解了~


    在這裡也可以很清楚的看出來,資料集是10000張影像/物件,設定batch_size=4之後DataLoader中就是2500個物件了

  • shuffle:是否在每次操作的時候打亂資料集,一般選擇為True。更好的理解就是,如果設定為true,那麼第一次建立DataLoader物件中的每個小物件和第二次DataLoader物件中的小物件是不相同的,因為在每次建立DataLoader的時候就會先把資料集打亂,自然得到的DataLoader也就不一樣

  • num_workers:多執行緒進行拿取資料操作,0表示只在主執行緒中操作,一般來講windows系統下用多執行緒會報錯,設定為0即可

  • drop_last:如果拿取資料有餘數,是否保留最後剩下的部分。如果為True的話,最後剩下的部分被丟棄,為false的話最後剩下的部分不會被丟棄

例如在後面的程式碼中,如果我設定drop_last=False,那麼一共有156次資料拿取,並且最後一次剩餘的部分不會被丟棄

如果設定drop_last=True,那麼最後剩餘的部分被丟棄,並且拿取次數也少了一次

使用DataLoader

初步使用的程式碼如下:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 準備的測試資料集
test_data=torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor())
# 載入Dataloader
test_dataloader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=True)
# 載入tensorboard
writer=SummaryWriter("logs")
# tensorboard中的圖片序號
step=0
for data in test_dataloader:
    images,targets=data
    writer.add_images("test_03",images,step)
    step=step+1
writer.close()

然後配合使用tensorboard就可以直觀體會到它的使用方法了~

相關文章