從零搭建Pytorch模型教程(一)資料讀取

CV技術指南(公眾號)發表於2022-03-04

 前言 

本文介紹了classdataset的幾個要點,由哪些部分組成,每個部分需要完成哪些事情,如何進行資料增強,如何實現自己設計的資料增強。然後,介紹了分散式訓練的資料載入方式,資料讀取的整個流程,當面對超大資料集時,記憶體不足的改進思路。

本文延續了以往的寫作態度和風格,即便是自己知道的內容,也仍然在寫之前看了很多的文章來保證內容的正確性和全面性,因此寫得極累,耗費時間較長。若有讀者看完後覺得有所幫助,文末可以讚賞一點。

文末掃描二維碼關注公眾號CV技術指南 ,專注於計算機視覺的技術總結、最新技術跟蹤、經典論文解讀,招聘資訊釋出。

 

(零) 概述


浮躁是人性的一個典型的弱點,很多人總擅長看別人分享的現成程式碼解讀的文章,看起來學會了好多東西,實際上仍然不具備自己從零搭建一個pipeline的能力。

在公眾號(CV技術指南)的交流群裡(群內交流氛圍不錯,有需要的請關注公眾號加群),常有不少人問到一些問題,根據這些問題明顯能看出是對pipeline不瞭解,卻已經在搞專案或論文了,很難想象如果基本的pipeline都不懂,如何分析程式碼問題所在?如何分析結果不正常的可能原因?遇到問題如何改?

Pytorch在這幾年逐漸成為了學術上的主流框架,其具有簡單易懂的特點。網上有很多pytorch的教程,如果是一個已經懂的人去看這些教程,確實pipeline的要素都寫到了,感覺這教程挺不錯的。但實際上更多地像是寫給自己看的一個筆記,記錄了pipeline要寫哪些東西,卻沒有介紹要怎麼寫,為什麼這麼寫,剛入門的小白看的時候容易雲裡霧裡。

鑑於此,本教程嘗試對於pytorch搭建一個完整pipeline寫一個比較明確且易懂的說明。

本教程將介紹以下內容:

  1. 準備資料,自定義classdataset,分散式訓練的資料載入方式,載入超大資料集的改進思路。

  2. 搭建模型與模型初始化。

  3. 編寫訓練過程,包括載入預訓練模型、設定優化器、設定損失函式等。

  4. 視覺化並儲存訓練過程。

  5. 編寫推理函式。

     

(一)資料讀取


classdataset的定義

先來看一個完整的classdataset

import torch.utils.data as data
import torchvision.transforms as transforms

class MyDataset(data.Dataset):
   def __init__(self,data_folder):
       self.data_folder = data_folder
       self.filenames = []
       self.labels = []

       per_classes = os.listdir(data_folder)
       for per_class in per_classes:
           per_class_paths = os.path.join(data_folder, per_class)
           label = torch.tensor(int(per_class))

           per_datas = os.listdir(per_class_paths)
           for per_data in per_datas:
               self.filenames.append(os.path.join(per_class_paths, per_data))
               self.labels.append(label)

   def __getitem__(self, index):
       image = Image.open(self.filenames[index])
       label = self.labels[index]
       data = self.proprecess(image)
       return data, label

   def __len__(self):
       return len(self.filenames)

   def proprecess(self,data):
       transform_train_list = [
           transforms.Resize((self.opt.h, self.opt.w), interpolation=3),
           transforms.Pad(self.opt.pad, padding_mode='edge'),
           transforms.RandomCrop((self.opt.h, self.opt.w)),
           transforms.RandomHorizontalFlip(),
           transforms.ToTensor(),
           transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
      ]
       return transforms.Compose(transform_train_list)

  

classdataset的幾個要點:

  1. classdataset類繼承torch.utils.data.dataset。

  2. classdataset的作用是將任意格式的資料,通過讀取、預處理或資料增強後以tensor的形式輸出。其中任意格式的資料可能是以資料夾名作為類別的形式、或以txt檔案儲存圖片地址的形式、或視訊、或十幾幀影像作為一份樣本的形式。而輸出則指的是經過處理後的一個batch的tensor格式資料和對應標籤。

  3. classdataset主要有三個函式要完成:__init__函式、__getitem__ 函式和__len__函式。

 

__init__函式

init函式主要是完成兩個靜態變數的賦值。一個是用於儲存所有資料路徑的變數,變數的每個元素即為一份訓練樣本,(注:如果一份樣本是十幾幀影像,則變數每個元素儲存的是這十幾幀影像的路徑),可以命名為self.filenames。一個是用於儲存與資料路徑變數一一對應的標籤變數,可以命名為self.labels。

假如資料集的格式如下:

#這裡的0,1指的是類別0,1
/data_path/0/image0.jpg
/data_path/0/image1.jpg
/data_path/0/image2.jpg
/data_path/0/image3.jpg
......
/data_path/1/image0.jpg
/data_path/1/image1.jpg
/data_path/1/image2.jpg
/data_path/1/image3.jpg

  

可通過per_classes = os.listdir(data_path) 獲得所有類別的資料夾,在此處per_classes的每個元素即為對應的資料標籤,通過for遍歷per_classes即可獲得每個類的標籤,將其轉換成int的tensor形式即可。在for下獲得每個類下每張圖片的路徑,通過self.join獲得每份樣本的路徑,通過append新增到self.filenames中。

 

__getitem__ 函式

getitem 函式主要是根據索引返回對應的資料。這個索引是在訓練前通過dataloader切片獲得的,這裡先不管。它的引數預設是index,即每次傳回在init函式中獲得的所有樣本中索引對應的資料和標籤。因此,可通過下面兩行程式碼找到對應的資料和標籤。

image = Image.open(self.filenames[index]))
label = self.labels[index]

  

獲得資料後,進行資料預處理。資料預處理主要通過 torchvision.transforms 來完成,這裡面已經包含了常用的預處理、資料增強方式。其完整使用方式在官網有詳細介紹:https://pytorch.org/vision/stable/transforms.html

上面這裡介紹了最常用的幾種,主要就是resize,隨機裁剪,翻轉,歸一化等。

最後通過transforms.Compose(transform_train_list)來執行。

 

除了這些已經有的資料增強方式外,在《資料增強方法總結》中還介紹了十幾種特殊的資料增強方式,像這種自己設計了一種新的資料增強方式,該如何新增進去呢

下面以隨機擦除作為例子。

class RandomErasing(object):
   """ Randomly selects a rectangle region in an image and erases its pixels.
      'Random Erasing Data Augmentation' by Zhong et al.
      See https://arxiv.org/pdf/1708.04896.pdf
  Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area.
        mean: Erasing value.
  """
   def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
       self.probability = probability
       self.mean = mean
       self.sl = sl
       self.sh = sh
       self.r1 = r1

   def __call__(self, img):
       if random.uniform(0, 1) > self.probability:
           return img
       for attempt in range(100):
           area = img.size()[1] * img.size()[2]
           target_area = random.uniform(self.sl, self.sh) * area
           aspect_ratio = random.uniform(self.r1, 1 / self.r1)
           h = int(round(math.sqrt(target_area * aspect_ratio)))
           w = int(round(math.sqrt(target_area / aspect_ratio)))
           if w < img.size()[2] and h < img.size()[1]:
               x1 = random.randint(0, img.size()[1] - h)
               y1 = random.randint(0, img.size()[2] - w)
               if img.size()[0] == 3:
                   img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                   img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                   img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
               else:
                   img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
               return img
       return img

  

從零搭建Pytorch模型教程(一)資料讀取

如上所示,自己寫一個類RandomErasing,繼承object,在call函式裡完成你的操作。在transform_train_list裡新增上RandomErasing的定義即可。

transform_train_list = [
          transforms.Resize((self.opt.h, self.opt.w), interpolation=3),
          transforms.Pad(self.opt.pad, padding_mode='edge'),
          transforms.RandomCrop((self.opt.h, self.opt.w)),
          transforms.RandomHorizontalFlip(),
          transforms.ToTensor(),
          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
          RandomErasing(probability=self.opt.erasing_p, mean=[0.0, 0.0, 0.0])
          #新增到這裡
      ]

  

 

__len__函式

len函式主要就是返回資料長度,即樣本的總數量。前面介紹了self.filenames的每個元素即為每份樣本的路徑,因此,self.filename的長度就是樣本的數量。通過return len(self.filenames)即可返回資料長度。

 

驗證classdataset

train_dataset = My_Dataset(data_folder=data_folder)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)
print('there are total %s batches for train' % (len(train_loader)))

for i,(data,label) in enumerate(train_loader):
    print(data.size(),label.size())

  

分散式訓練的資料載入方式


前面介紹的是單卡的資料載入,實際上分散式也是這樣,但為了高速高效讀取,每張卡上也會儲存所有資料的資訊,即self.filenames和self.labels的資訊。只是在DistributedSampler 中會給每張卡分配互不交叉的索引,然後由torch.utils.data.DataLoader來載入。

dataset = My_Dataset(data_folder=data_folder)
sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)

 

 

資料讀取的完整流程


結合上面這段程式碼,在這裡,我們介紹以下讀取資料的整個流程。

  1. 首先定義一個classdataset,在初始化函式裡獲得所有資料的資訊。

  2. classdataset中實現getitem函式,通過索引來獲取對應的資料,然後對資料進行預處理和資料增強。

  3. 在模型訓練前,初始化classdataset,通過Dataloader來載入資料,其載入方式是通過Dataloader中分配的索引,呼叫getitem函式來獲取。

    關於索引的分配,在單卡上,可通過設定shuffle=True來隨機生成索引順序;在多機多卡的分散式訓練上,shuffle操作通過DistributedSampler來完成,因此shuffle與sampler只能有一個,另一個必須為None。

 

超大資料集的載入思路


問題所在

再回顧一下上面這個流程,前面提到所有資料資訊在classdataset初始化部分都會儲存在變數中,因此當面對超大資料集時,會出現記憶體不足的情況。

思路

將切片獲取索引的步驟放到classdataset初始化的位置,此時每張卡都是儲存不同的資料子集。通過這種方式,可以將記憶體用量減少到原來的world_size倍(world_size指卡的數量)。

參考程式碼

class RankDataset(Dataset):
   '''
  實際流程
  獲取rank和world_size 資訊 -> 獲取dataset長度 -> 根據dataset長度產生隨機indices ->
  給不同的rank 分配indices -> 根據這些indices產生metas
  '''
   def __init__(self, meta_file, world_size, rank, seed):
       super(RankDataset, self).__init__()
       random.seed(seed)
       np.random.seed(seed)
       self.world_size = world_size
       self.rank = rank
       self.metas = self.parse(meta_file)

   def parse(self, meta_file):
       dataset_size = self.get_dataset_size(meta_file)                                     # 獲取metafile的行數
       local_rank_index = self.get_local_index(dataset_size, self.rank, self.world_size)   # 根據world size和rank,獲取當前epoch,當前rank需要訓練的index。
       self.metas = self.read_file(meta_file, local_rank_index)


   def __getitem__(self, idx):
       return self.metas[idx]

   def __len__(self):
       return len(self.metas)
   
##train
for epoch_num in range(epoch_num):
   dataset = RankDataset("/path/to/meta", world_size, rank, seed=epoch_num)
   sampler = RandomSampler(datset)
   dataloader = DataLoader(
               dataset=dataset,
               batch_size=32,
               shuffle=False,
               num_workers=4,
               sampler=sampler)

  

但這種思路比較明顯的問題時,為了讓每張卡上在每個epoch都載入不同的訓練子集,因此需要在每個epoch重新build dataloader。

這一節參考連結:https://zhuanlan.zhihu.com/p/357809861

 

總結


本篇文章介紹了資料讀取的完整流程,如何自定義classdataset,如何進行資料增強,自己設計的資料增強如何寫,分散式訓練是如何載入資料的,超大資料集的資料載入改進思路。

相信讀完本文的讀者對資料讀取有了比較清晰的認識,下一篇將介紹搭建模型與模型初始化。

關注公眾號可加計算機視覺交流群

歡迎關注公眾號 CV技術指南 ,專注於計算機視覺的技術總結、最新技術跟蹤、經典論文解讀。

在公眾號中回覆關鍵字 “入門指南“可獲取計算機視覺入門所有必備資料。

從零搭建Pytorch模型教程(一)資料讀取

 

其它文章

 

自編碼器綜述論文:概念、圖解和應用

解決影像分割落地場景真實問題,港中文等提出:開放世界實體分割

資源分享 |Nebullvm:一行程式碼測試多個DL編譯器,模型推理提高5-20倍

目標檢測、例項分割、多目標跟蹤的Anchor-free應用方法總結

Soft Sampling:探索更有效的取樣策略

如何解決工業缺陷檢測小樣本問題

機器學習、深度學習面試知識點彙總

深度學習影像識別的未來:機遇與挑戰並存

招聘 | 22-65k!遷移科技:招聘深度學習、傳統視覺、3D視覺演算法工程師、專案經理、機械設計

關於快速學習一項新技術或新領域的一些個人思維習慣與思想總結

計算機視覺中的影像標註工具總結

計算機視覺中的神經網路視覺化工具與專案

計算機視覺中的高效閱讀論文的方法總結

計算機視覺中的transformer模型創新思路總結

一文概括機器視覺常用演算法以及常用開發庫

HOG和SIFT影像特徵提取簡述    |  特徵金字塔技術總結

目標檢測中迴歸損失函式總結    |    例項分割綜述總結綜合整理版

2021年小目標檢測最新研究綜述    |    小目標檢測常用方法總結

相關文章