PyTorch使用總覽

Terminator2050發表於2018-09-03

 

PyTorch使用總覽

 

https://blog.csdn.net/u014380165/article/details/79222243

 

深度學習框架訓練模型時的程式碼主要包含資料讀取、網路構建和其他設定三方面,基本上掌握這三方面就可以較為靈活地使用框架訓練模型。PyTorch是Facebook的官方深度學習框架之一,到現在開源1年時間,勢頭非常猛,相信使用過的人都會被其輕便和快速等特點深深吸引,因此這篇部落格從整體上介紹如何使用PyTorch

PyTorch的官方github地址:https://github.com/pytorch/pytorch
PyTorch官方文件:http://pytorch.org/docs/0.3.0/

 

建議先看看:PyTorch學習之路(level1)——訓練一個影像分類模型,對Pytorch的使用有一個快速的瞭解。

 

接下來就按照上述的3個方面來介紹如何使用PyTorch。

 

一、資料讀取

 

資料讀取部分包含如何將你的影像和標籤資料轉換成PyTorch框架的Tensor資料型別,官方程式碼庫中有一個介面例子:torchvision.ImageFolder,這個介面在PyTorch學習之路(level1)——訓練一個影像分類模型 中有簡單介紹。因為這個介面針對的資料存放方式是每個資料夾包含一個類的影像,但是實際應用中可能你的資料不是這樣維護的,或者你的資料是多標籤的,或者其他更復雜的形式,那麼就需要自定義一個資料讀取介面,這個時候就不得不提一個PyTorch中資料讀取基類:torch.utils.data.Dataset,包括前面提到的torchvision.ImageFolder介面的對應類也是繼承torch.utils.data.Dataset實現的,因此torch.utils.data.Dataset類是PyTorch框架中資料讀取的核心。那麼如何自定義一個資料讀取介面呢?可以看部落格: PyTorch學習之路(level2)——自定義資料讀取,這篇部落格中從剖析torchvision.ImageFolder介面切入,然後引出如何自定義資料讀取介面。這樣就完成了資料的第一層封裝。

 

在自定義資料讀取介面時還有一步很重要的操作:資料預處理。常常我們在論文中看到的data argumentation就是指的資料預處理,對實驗結果影響還是比較大的。該操作在PyTorch中可以通過torchvision.transforms介面來實現,具體請看部落格:PyTorch原始碼解讀之torchvision.transforms 的介紹。

 

經過上述的兩個操作後,還需再進行一次封裝,將資料和標籤封裝成資料迭代器,這樣才方便模型訓練的時候一個batch一個batch地進行,這就要用到torch.utils.data.DataLoader介面,該介面的一個輸入就是前面繼承自torch.utils.data.Dataset類的自定義了的物件(比如torchvision.ImageFolder類的物件),具體可以參考部落格: PyTorch原始碼解讀之torch.utils.data.DataLoader

至此,從影像和標籤檔案就生成了Tensor型別的資料迭代器,後續僅需將Tensor物件用torch.autograd.Variable介面封裝成Variable型別(比如train_data=torch.autograd.Variable(train_data),如果要在gpu上執行則是:train_data=torch.autograd.Variable(train_data.cuda()))就可以作為模型的輸入了。

 

其他自定義的資料讀取介面例子可以參考:https://github.com/miraclewkf/MobileNetV2-PyTorch,該專案中的read_ImageNetData.py指令碼自定義了讀取ImageNet資料集的介面,訓練資料的讀取和驗證資料的讀取採取不同的介面實現,比較有特點。

 

二、網路構建

 

PyTorch框架中提供了一些方便使用的網路結構及預訓練模型介面:torchvision.models,具體可以看部落格:PyTorch原始碼解讀之torchvision.models。該介面可以直接匯入指定的網路結構,並且可以選擇是否用預訓練模型初始化匯入的網路結構。

 

那麼如何自定義網路結構呢?在PyTorch中,構建網路結構的類都是基於torch.nn.Module這個基類進行的,也就是說所有網路結構的構建都可以通過繼承該類來實現,包括torchvision.models介面中的模型實現類也是繼承這個基類進行重寫的。自定義網路結構可以參考:1、https://github.com/miraclewkf/MobileNetV2-PyTorch。該專案中的MobileNetV2.py指令碼自定義了網路結構。2、https://github.com/miraclewkf/SENet-PyTorch。該專案中的se_resnet.py和se_resnext.py指令碼分別自定義了不同的網路結構。

 

如果要用某預訓練模型為自定義的網路結構進行引數初始化,可以用torch.load介面匯入預訓練模型,然後呼叫自定義的網路結構物件的load_state_dict方式進行引數初始化,具體可以看https://github.com/miraclewkf/MobileNetV2-PyTorch專案中的train.py指令碼中if args.resume條件語句。

 

三、其他設定

 

優化函式通過torch.optim包實現,比如torch.optim.SGD()介面表示隨機梯度下降。更多優化函式可以看官方文件:http://pytorch.org/docs/0.3.0/optim.html

 

學習率策略通過torch.optim.lr_scheduler介面實現,比如torch.optim.lr_scheduler.StepLR()介面表示按指定epoch數減少學習率。更多學習率變化策略可以看官方文件:http://pytorch.org/docs/0.3.0/optim.html

 

損失函式通過torch.nn包實現,比如torch.nn.CrossEntropyLoss()介面表示交叉熵等。

 

多GPU訓練通過torch.nn.DataParallel介面實現,比如:model = torch.nn.DataParallel(model, device_ids=[0,1])表示在gpu0和1上訓練模型。

 

相關文章