前言
本文介紹了classdataset的幾個要點,由哪些部分組成,每個部分需要完成哪些事情,如何進行資料增強,如何實現自己設計的資料增強。然後,介紹了分散式訓練的資料載入方式,資料讀取的整個流程,當面對超大資料集時,記憶體不足的改進思路。
本文延續了以往的寫作態度和風格,即便是自己知道的內容,也仍然在寫之前看了很多的文章來保證內容的正確性和全面性,因此寫得極累,耗費時間較長。若有讀者看完後覺得有所幫助,文末可以讚賞一點。
文末掃描二維碼關注公眾號CV技術指南 ,專注於計算機視覺的技術總結、最新技術跟蹤、經典論文解讀,招聘資訊釋出。
(零) 概述
浮躁是人性的一個典型的弱點,很多人總擅長看別人分享的現成程式碼解讀的文章,看起來學會了好多東西,實際上仍然不具備自己從零搭建一個pipeline的能力。
在公眾號(CV技術指南)的交流群裡(群內交流氛圍不錯,有需要的請關注公眾號加群),常有不少人問到一些問題,根據這些問題明顯能看出是對pipeline不瞭解,卻已經在搞專案或論文了,很難想象如果基本的pipeline都不懂,如何分析程式碼問題所在?如何分析結果不正常的可能原因?遇到問題如何改?
Pytorch在這幾年逐漸成為了學術上的主流框架,其具有簡單易懂的特點。網上有很多pytorch的教程,如果是一個已經懂的人去看這些教程,確實pipeline的要素都寫到了,感覺這教程挺不錯的。但實際上更多地像是寫給自己看的一個筆記,記錄了pipeline要寫哪些東西,卻沒有介紹要怎麼寫,為什麼這麼寫,剛入門的小白看的時候容易雲裡霧裡。
鑑於此,本教程嘗試對於pytorch搭建一個完整pipeline寫一個比較明確且易懂的說明。
本教程將介紹以下內容:
-
準備資料,自定義classdataset,分散式訓練的資料載入方式,載入超大資料集的改進思路。
-
搭建模型與模型初始化。
-
編寫訓練過程,包括載入預訓練模型、設定優化器、設定損失函式等。
-
視覺化並儲存訓練過程。
-
編寫推理函式。
(一)資料讀取
classdataset的定義
先來看一個完整的classdataset
classdataset的幾個要點:
-
classdataset類繼承torch.utils.data.dataset。
-
classdataset的作用是將任意格式的資料,通過讀取、預處理或資料增強後以tensor的形式輸出。其中任意格式的資料可能是以資料夾名作為類別的形式、或以txt檔案儲存圖片地址的形式、或視訊、或十幾幀影像作為一份樣本的形式。而輸出則指的是經過處理後的一個batch的tensor格式資料和對應標籤。
-
classdataset主要有三個函式要完成:__init__函式、__getitem__ 函式和__len__函式。
__init__函式
init函式主要是完成兩個靜態變數的賦值。一個是用於儲存所有資料路徑的變數,變數的每個元素即為一份訓練樣本,(注:如果一份樣本是十幾幀影像,則變數每個元素儲存的是這十幾幀影像的路徑),可以命名為self.filenames。一個是用於儲存與資料路徑變數一一對應的標籤變數,可以命名為self.labels。
假如資料集的格式如下:
可通過per_classes = os.listdir(data_path) 獲得所有類別的資料夾,在此處per_classes的每個元素即為對應的資料標籤,通過for遍歷per_classes即可獲得每個類的標籤,將其轉換成int的tensor形式即可。在for下獲得每個類下每張圖片的路徑,通過self.join獲得每份樣本的路徑,通過append新增到self.filenames中。
__getitem__ 函式
getitem 函式主要是根據索引返回對應的資料。這個索引是在訓練前通過dataloader切片獲得的,這裡先不管。它的引數預設是index,即每次傳回在init函式中獲得的所有樣本中索引對應的資料和標籤。因此,可通過下面兩行程式碼找到對應的資料和標籤。
獲得資料後,進行資料預處理。資料預處理主要通過 torchvision.transforms 來完成,這裡面已經包含了常用的預處理、資料增強方式。其完整使用方式在官網有詳細介紹:https://pytorch.org/vision/stable/transforms.html
上面這裡介紹了最常用的幾種,主要就是resize,隨機裁剪,翻轉,歸一化等。
最後通過transforms.Compose(transform_train_list)來執行。
除了這些已經有的資料增強方式外,在《
下面以隨機擦除作為例子。
如上所示,自己寫一個類RandomErasing,繼承object,在call函式裡完成你的操作。在transform_train_list裡新增上RandomErasing的定義即可。
transform_train_list = [ transforms.Resize((self.opt.h, self.opt.w), interpolation=3), transforms.Pad(self.opt.pad, padding_mode='edge'), transforms.RandomCrop((self.opt.h, self.opt.w)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) RandomErasing(probability=self.opt.erasing_p, mean=[0.0, 0.0, 0.0]) #新增到這裡 ]
__len__函式
len函式主要就是返回資料長度,即樣本的總數量。前面介紹了self.filenames的每個元素即為每份樣本的路徑,因此,self.filename的長度就是樣本的數量。通過return len(self.filenames)即可返回資料長度。
驗證classdataset
分散式訓練的資料載入方式
前面介紹的是單卡的資料載入,實際上分散式也是這樣,但為了高速高效讀取,每張卡上也會儲存所有資料的資訊,即self.filenames和self.labels的資訊。只是在DistributedSampler 中會給每張卡分配互不交叉的索引,然後由torch.utils.data.DataLoader來載入。
dataset = My_Dataset(data_folder=data_folder) sampler = DistributedSampler(dataset) if is_distributed else None loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
資料讀取的完整流程
結合上面這段程式碼,在這裡,我們介紹以下讀取資料的整個流程。
-
首先定義一個classdataset,在初始化函式裡獲得所有資料的資訊。
-
classdataset中實現getitem函式,通過索引來獲取對應的資料,然後對資料進行預處理和資料增強。
-
在模型訓練前,初始化classdataset,通過Dataloader來載入資料,其載入方式是通過Dataloader中分配的索引,呼叫getitem函式來獲取。
關於索引的分配,在單卡上,可通過設定shuffle=True來隨機生成索引順序;在多機多卡的分散式訓練上,shuffle操作通過DistributedSampler來完成,因此shuffle與sampler只能有一個,另一個必須為None。
超大資料集的載入思路
問題所在
再回顧一下上面這個流程,前面提到所有資料資訊在classdataset初始化部分都會儲存在變數中,因此當面對超大資料集時,會出現記憶體不足的情況。
思路
將切片獲取索引的步驟放到classdataset初始化的位置,此時每張卡都是儲存不同的資料子集。通過這種方式,可以將記憶體用量減少到原來的world_size倍(world_size指卡的數量)。
參考程式碼
class RankDataset(Dataset): ''' 實際流程 獲取rank和world_size 資訊 -> 獲取dataset長度 -> 根據dataset長度產生隨機indices -> 給不同的rank 分配indices -> 根據這些indices產生metas ''' def __init__(self, meta_file, world_size, rank, seed): super(RankDataset, self).__init__() random.seed(seed) np.random.seed(seed) self.world_size = world_size self.rank = rank self.metas = self.parse(meta_file) def parse(self, meta_file): dataset_size = self.get_dataset_size(meta_file) # 獲取metafile的行數 local_rank_index = self.get_local_index(dataset_size, self.rank, self.world_size) # 根據world size和rank,獲取當前epoch,當前rank需要訓練的index。 self.metas = self.read_file(meta_file, local_rank_index) def __getitem__(self, idx): return self.metas[idx] def __len__(self): return len(self.metas) ##train for epoch_num in range(epoch_num): dataset = RankDataset("/path/to/meta", world_size, rank, seed=epoch_num) sampler = RandomSampler(datset) dataloader = DataLoader( dataset=dataset, batch_size=32, shuffle=False, num_workers=4, sampler=sampler)
這一節參考連結:https://zhuanlan.zhihu.com/p/357809861
總結
本篇文章介紹了資料讀取的完整流程,如何自定義classdataset,如何進行資料增強,自己設計的資料增強如何寫,分散式訓練是如何載入資料的,超大資料集的資料載入改進思路。
相信讀完本文的讀者對資料讀取有了比較清晰的認識,下一篇將介紹搭建模型與模型初始化。
關注公眾號可加計算機視覺交流群
歡迎關注公眾號 CV技術指南 ,專注於計算機視覺的技術總結、最新技術跟蹤、經典論文解讀。
在公眾號中回覆關鍵字 “入門指南“可獲取計算機視覺入門所有必備資料。
其它文章