Pytorch資料讀取與預處理實現與探索

頎周發表於2021-03-26

  在煉丹時,資料的讀取與預處理是關鍵一步。不同的模型所需要的資料以及預處理方式各不相同,如果每個輪子都我們自己寫的話,是很浪費時間和精力的。Pytorch幫我們實現了方便的資料讀取與預處理方法,下面記錄兩個DEMO,便於加快以後的程式碼效率。

  根據資料是否一次性讀取完,將DEMO分為:

  1、序列式讀取。也就是一次性讀取完所有需要的資料到記憶體,模型訓練時不會再訪問外存。通常用在記憶體足夠的情況下使用,速度更快。

  2、並行式讀取。也就是邊訓練邊讀取資料。通常用在記憶體不夠的情況下使用,會佔用計算資源,如果分配的好的話,幾乎不損失速度。

  Pytorch官方的資料提取方式儘管方便編碼,但由於它提取資料方式比較死板,會浪費資源,下面對其進行分析。

序列式讀取

DEMO程式碼

import torch 
from torch.utils.data import Dataset,DataLoader 
   
class MyDataSet(Dataset):# ————1————
  def __init__(self):    
    self.data = torch.tensor(range(10)).reshape([5,2])
    self.label = torch.tensor(range(5))

  def __getitem__(self, index):   
    return self.data[index], self.label[index]

  def __len__(self):    
    return len(self.data)
  
my_data_set = MyDataSet()# ————2————
my_data_loader = DataLoader(
  dataset=my_data_set,   # ————3————
  batch_size=2,          # ————4————
  shuffle=True,          # ————5————
  sampler=None,          # ————6————
  batch_sampler=None,    # ————7———— 
  num_workers=0 ,        # ————8———— 
  collate_fn=None,       # ————9———— 
  pin_memory=True,       # ————10———— 
  drop_last=True         # ————11————
)

for i in my_data_loader: # ————12————
  print(i)

  註釋處解釋如下:

  1、重寫資料集類,用於儲存資料。除了 __init__() 外,必須實現 __getitem__() 和 __len__() 兩個方法。前一個方法用於輸出索引對應的資料。後一個方法用於獲取資料集的長度。

  2~5、 2準備好資料集後,傳入DataLoader來迭代生成資料。前三個引數分別是傳入的資料集物件、每次獲取的批量大小、是否打亂資料集輸出。

  6、取樣器,如果定義這個,shuffle只能設定為False。所謂取樣器就是用於生成資料索引的可迭代物件,比如列表。因此,定義了取樣器,取樣都按它來,shuffle再打亂就沒意義了。

  7、批量取樣器,如果定義這個,batch_size、shuffle、sampler、drop_last都不能定義。實際上,如果沒有特殊的資料生成順序的要求,取樣器並沒有必要定義。torch.utils.data 中的各種 Sampler 就是取樣器類,如果需要,可以使用它們來定義。

  8、用於生成資料的子程式數。預設為0,不併行。

  9、拼接多個樣本的方法,預設是將每個batch的資料在第一維上進行拼接。這樣可能說不清楚,並且由於這裡可以探究一下獲取資料的速度,後面再詳細說明。

  10、是否使用鎖頁記憶體。用的話會更快,記憶體不充足最好別用。

  11、是否把最後小於batch的資料丟掉。

  12、迭代獲取資料並輸出。

速度探索

  首先看一下DEMO的輸出:

  輸出了兩個batch的資料,每組資料中data和label都正確排列,符合我們的預期。那麼DataLoader是怎麼把資料整合起來的呢?首先,我們把collate_fn定義為直接對映(不用它預設的方法),來檢視看每次DataLoader從MyDataSet中讀取了什麼,將上面部分程式碼修改如下:

my_data_loader = DataLoader(
  dataset=my_data_set,    
  batch_size=2,           
  shuffle=True,           
  sampler=None,         
  batch_sampler=None,    
  num_workers=0 ,        
  collate_fn=lambda x:x,  #修改處
  pin_memory=True,       
  drop_last=True         
)

  結果如下:

  輸出還是兩個batch,然而每個batch中,單個的data和label是在一個list中的。似乎可以看出,DataLoader是一個一個讀取MyDataSet中的資料的,然後再進行相應資料的拼接。為了驗證這點,程式碼修改如下:

import torch 
from torch.utils.data import Dataset,DataLoader 
   
class MyDataSet(Dataset): 
  def __init__(self):    
    self.data = torch.tensor(range(10)).reshape([5,2])
    self.label = torch.tensor(range(5))

  def __getitem__(self, index):   
    print(index)          #修改處2
    return self.data[index], self.label[index]

  def __len__(self):    
    return len(self.data)
  
my_data_set = MyDataSet() 
my_data_loader = DataLoader(
  dataset=my_data_set,    
  batch_size=2,           
  shuffle=True,           
  sampler=None,         
  batch_sampler=None,    
  num_workers=0 ,        
  collate_fn=lambda x:x,  #修改處1
  pin_memory=True,       
  drop_last=True         
)

for i in my_data_loader:  
  print(i) 

  輸出如下:

  驗證了前面的猜想,的確是一個一個讀取的。如果資料集定義的不是格式化的資料,那還好,但是我這裡定義的是tensor,是可以直接通過列表來索引對應的tensor的。因此,DataLoader的操作比直接索引多了拼接這一步,肯定是會慢很多的。一兩次的讀取還好,但在訓練中,大量的讀取累加起來,就會浪費很多時間了。

  自定義一個DataLoader可以證明這一點,程式碼如下:

import torch 
from torch.utils.data import Dataset,DataLoader 
from time  import time
   
class MyDataSet(Dataset): 
  def __init__(self):    
    self.data = torch.tensor(range(100000)).reshape([50000,2])
    self.label = torch.tensor(range(50000))

  def __getitem__(self, index):    
    return self.data[index], self.label[index]

  def __len__(self):    
    return len(self.data)

# 自定義DataLoader
class MyDataLoader():
  def __init__(self, dataset,batch_size):
    self.dataset = dataset
    self.batch_size = batch_size
  def __iter__(self):
    self.now = 0
    self.shuffle_i = np.array(range(self.dataset.__len__())) 
    np.random.shuffle(self.shuffle_i)
    return self
 
  def __next__(self): 
    self.now += self.batch_size
    if self.now <= len(self.shuffle_i):
      indexes = self.shuffle_i[self.now-self.batch_size:self.now]
      return self.dataset.__getitem__(indexes)
    else:
      raise StopIteration

# 使用官方DataLoader
my_data_set = MyDataSet() 
my_data_loader = DataLoader(
  dataset=my_data_set,    
  batch_size=256,           
  shuffle=True,           
  sampler=None,         
  batch_sampler=None,    
  num_workers=0 ,        
  collate_fn=None,  
  pin_memory=True,       
  drop_last=True         
)

start_t = time()
for t in range(10):
  for i in my_data_loader:  
    pass
print("官方:", time() - start_t)
 
 
#自定義DataLoader
my_data_set = MyDataSet() 
my_data_loader = MyDataLoader(my_data_set,256)

start_t = time()
for t in range(10):
  for i in my_data_loader:  
    pass
print("自定義:", time() - start_t)

  執行結果如下:

  以上使用batch大小為256,僅各讀取10 epoch的資料,都有30多倍的時間上的差距,更大的batch差距會更明顯。另外,這裡用於測試的每個資料只有兩個浮點數,如果是影像,所需的時間可能會增加幾百倍。因此,如果資料量和batch都比較大,並且資料是格式化的,最好自己寫資料生成器。

並行式讀取

DEMO程式碼

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader 
from torchvision import transforms 
from torchvision.datasets import ImageFolder  
  
path = r'E:\DataSets\ImageNet\ILSVRC2012_img_train\10-19\128x128'
my_data_set = ImageFolder(            #————1————
  root = path,                        #————2————
  transform = transforms.Compose([    #————3————
    transforms.ToTensor(),
    transforms.CenterCrop(64)
  ]),
  loader = plt.imread                 #————4————
)
my_data_loader = DataLoader(
  dataset=my_data_set,      
  batch_size=128,             
  shuffle=True,             
  sampler=None,             
  batch_sampler=None,        
  num_workers=0,            
  collate_fn=None,           
  pin_memory=True,           
  drop_last=True 
)           

for i in my_data_loader: 
  print(i)

  註釋處解釋如下:

  1/2、ImageFolder類繼承自DataSet類,因此可以按索引讀取影像。路徑必須包含資料夾,ImageFolder會給每個資料夾中的影像新增索引,並且每張影像會給予其所在資料夾的標籤。舉個例子,程式碼中my_data_set[0] 輸出的是影像物件和它對應的標籤組成的列表。

  3、影像到格式化資料的轉換組合。更多的轉換方法可以看 transform 模組。

  4、影像法的讀取方式,預設是PIL.Image.open(),但我發現plt.imread()更快一些。

  由於是邊訓練邊讀取,transform會佔用很多時間,因此可以先將影像轉換為需要的形式存入外存再讀取,從而避免重複操作。

  其中transform.ToTensor()會把正常讀取的影像轉換為torch.tensor,並且畫素值會對映至$[0,1]$。由於plt.imread()讀取png影像時,畫素值在$[0,1]$,而讀取jpg影像時,畫素值卻在$[0,255]$,因此使用transform.ToTensor()能將影像畫素區間統一化。

相關文章