Pytorch系列:(二)資料載入

Q發表於2021-04-24

DataLoader

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,
batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,
drop_last=False,timeout=0,work_init_fn=None)

常用引數說明:

  • dataset: Dataset類 ( 詳見下文資料集構建 ),可以自定義資料集或者讀取pytorch自帶資料集

  • batch_size: 每個batch載入多少個樣本, 預設1

  • shuffle: 是否順序讀取,True表示隨機打亂,預設False

  • sampler:定義從資料集中提取樣本的策略。如果指定,則忽略shuffle引數。

  • batch_sampler: 定義一個按照batch_size大小返回索引的取樣器。取樣器詳見下文Batch_Sampler

  • num_workers: 資料讀取程式數量, 預設0

  • collate_fn: 自定義一個函式,接收一個batch的資料,進行自定義處理,然後返回處理後這個batch的資料。例如改變資料型別:

def my_collate_fn(batch_data):
    x_batch = []
    y_batch = []
    for x,y in batch_data:
        x_batch.append(x.float())
        y_batch.append(y.int())
    return x_batch,y_batch

  • pin_memory:設定pin_memory=True,則意味著生成的Tensor資料最開始是屬於記憶體中的鎖頁記憶體,這樣將記憶體的Tensor轉義到GPU的視訊記憶體就會更快一些。預設為False.

    主機中的記憶體,有兩種,一種是鎖頁,一種是不鎖頁。鎖頁記憶體存放的內容在任何情況下都不會與主機的虛擬記憶體 (硬碟)進行交換,而不鎖頁記憶體在主機記憶體不足時,資料會存放在虛擬記憶體中。注意顯示卡中的視訊記憶體全部都是鎖業記憶體。如果計算機記憶體充足的話,設定為True可以加快資料交換順序。

  • drop_last:預設False, 最後剩餘資料量不夠batch_size時候,是否丟棄。

  • timeout: 設定資料讀取的時間限制,超過限制時間還未完成資料讀取則報錯。數值必須大於等於0

資料集構建

自定義資料集

自定義資料集,需要繼承torch.utils.data.Dataset,然後在__getitem__()中,接受一個索引,返回一個樣本, 基本流程,首先在__init__()載入資料以及做一些處理,在__getitem__()中返回單個資料樣本,在__len__() 中,返回樣本數量

import torch
import torch.utils.data.dataset as Data 

class MyDataset(Data.Dataset):

    def __init__(self):
            self.x = torch.randn((10,20))
            self.y = torch.tensor([1 if i>5 else 0 for i in range(10)],
            dtype=torch.long)
            
    def __getitem__(self,idx):
        return self.x[idx],self.y[idx]
        
    def __len__(self):
        return self.x.__len__()
       

torchvision資料集

pytorch自帶torchvision庫可以幫助我們方便快捷的讀取和載入資料

import torch
from torchvision import datasets, transforms

# 定義一個預處理方法
transform = transforms.Compose([transforms.ToTensor()])
# 載入一個自帶資料集
trainset = datasets.MNIST('/pytorch/MNIST_data/', download=True, train=True, 
transform=transform)

TensorDataset

注意這裡的tensor必須是一維度的資料。

import torch.utils.data as Data
x = torch.tensor([1,2,3,4,5])
y = torch.tensor([0,0,0,1,1])
dataset = Data.TensorDataset(x,y) 

從資料夾中載入資料集

如果想要載入自己的資料集可以這樣,用貓狗資料集舉例,根目錄下 ( "data/train" ),分別放置兩個資料夾,dog和cat,這樣使用ImageFolder函式就可以自動的將貓狗照片自動的按照資料夾定義為貓狗兩個標籤

import torch
from torchvision import datasets, transforms

data_dir = "data/train"
transform = transforms.Compose([transforms.Resize(255),transforms.ToTensor()])

dataset = datasets.ImageFolder(data_dir, transform=transform)


資料集操作

資料拼接

連線不同的資料集以構成更大的新資料集。

class torch.utils.data.ConcatDataset( [datasets, ... ] )

newDataset = torch.utils.data.ConcatDataset([dataset1,dataset2])

資料切分

方法一: class torch.utils.data.Subset(dataset, indices)

取指定一個索引序列對應的子資料集。

from torch.utils.data import Subset

train_set = Subset(dataset,[i for i in range(1,100)]
test_set = Subset(test0_ds,[i for i in range(100,150)]

方法二:torch.utils.data.random_split(dataset, lengths)

from torch.utils.data import random_split
train_set, test_set =  random_split(dataset,[100,50])

取樣器

所有采樣器都在 torch.utils.data 中,取樣器會根據該有的策略返回一組索引,在DataLoader中設定了取樣器之後,會根據索引讀取相應的樣本, 不同取樣器生成的索引不一樣,從而實現不同的取樣目的。

Sampler

所有采樣器的基類,自定義取樣器的時候需要實現 __iter__() 函式

class Sampler(object):
    """ 
    Base class for all Samplers.
    """
    
    def __init__(self, data_source):
        pass
    
    def __iter__(self):
        raise NotImplementedError

RandomSampler

RandomSampler,當DataLoader的shuffle引數為True時,系統會自動呼叫這個取樣器,實現打亂資料。預設的是採用SequentialSampler,它會按順序一個一個進行取樣。

SequentialSampler

按順序取樣,當DataLoader的shuffle引數為False時,使用的就是SequentialSampler。

SubsetRandomSampler

輸入一個列表,按照這個列表取樣。也可以通過這個取樣器來分割資料集。

BatchSampler

引數:sampler, batch_size, drop_last

每此返回batch_size數量的取樣索引,通過設定sampler引數來使用不同的取樣方法。

WeightedRandomSampler

引數:weights, num_samples, replacement

它會根據每個樣本的權重選取資料,在樣本比例不均衡的問題中,可用它來進行重取樣。通過weights 設定樣本權重,權重越大的樣本被選中的概率越大,待選取的樣本數目一般小於全部的樣本數目。num_samples 為返回索引的數量,replacement表示是否是放回抽樣,如果為True,表示可以重複取樣,預設為True

自定義取樣器

整合Sampler類,然後實現__iter__() 方法,比如,下面實現一個SequentialSampler類

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.
    Arguments:
        data_source (Dataset): dataset to sample from
    """
   
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

相關文章