【小白學PyTorch】6 模型的構建訪問遍歷儲存(附程式碼)

忽逢桃林發表於2020-09-05

文章轉載自微信公眾號:機器學習煉丹術。歡迎大家關注,這是我的學習分享公眾號,100+原創乾貨。

文章目錄:

本文是對一些函式的學習。函式主要包括下面四個方便:

  • 模型構建的函式:add_module,add_module,add_module
  • 訪問子模組:add_module,add_module,add_moduleadd_module
  • 網路遍歷:
    add_module,add_module
  • 模型的儲存與載入:add_module,add_module,add_module

1 模型構建函式

torch.nn.Module是所有網路的基類,在PyTorch實現模型的類中都要繼承這個類(這個在之前的課程中已經提到)。在構建Module中,Module是一個包含其他的Module的,類似於,你可以先定義一個小的網路模組,然後把這個小模組作為另外一個網路的元件。因此網路結構是呈現樹狀結構

我們先簡單定義一個網路:

import torch.nn as nn
import torch 
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.conv1 = nn.Conv2d(3,64,3)
        self.conv2 = nn.Conv2d(64,64,3)

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
net = MyNet()
print(net)

輸出結果:

MyNet中有兩個屬性conv1conv2是兩個卷積層,在正向傳播forward的過程中,依次呼叫這兩個卷積層實現網路的功能。

1.1 add_module

這種是最常見的定義網路的功能,在有些專案中,會看到這樣的方法add_module。我們用這個方法來重寫上面的網路:

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.add_module('conv1',nn.Conv2d(3,64,3))
        self.add_module('conv2',nn.Conv2d(64,64,3))

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

其實add_module(name,layer)self.name=layer實現了相同的功能,個人感覺也許是因為add_module可以使用字串來定義變數名字,所以可以放在迴圈中?反正這個先了解熟悉熟悉

上面的兩種方法都是一層一層的新增layer,如果網路複雜的話,那就需要寫很多重複的程式碼了。因此接下來來講解一下網路模組的構建,torch.nn.ModuleListtorch.nn.Sequential

1.2 ModuleList

ModuleList按照字面意思是用list的形式儲存網路層的。這樣就可以先將網路需要的layer構建好,儲存到一個list,然後通過ModuleList方法新增到網路中.

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(10,10) for i in range(5)]
        )

    def forward(self,x):
        for l in self.linears:
            x = l(x)
        return x
net = MyNet()
print(net)

輸出結果是:

這個ModuleList主要是用在讀取config檔案來構建網路模型中的,下面用VGG模型的構建為例子:

vgg_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
           512, 512, 512, 'M']

def vgg(cfg, i, batch_norm=False):
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return layers

class Model1(nn.Module):
    def __init__(self):
        super(Model1,self).__init__()

        self.vgg = nn.ModuleList(vgg(vgg_cfg,3))

    def forward(self,x):

        for l in self.vgg:
            x = l(x)
m1 = Model1()
print(m1)

先讀取網路結構的配置檔案vgg_cfg然後根據這個檔案建立對應的Layer list,然後使用ModuleList新增到網路中,這樣可以快速建立不同的網路(用上面為例子的話,可以通過修改配置檔案,然後快速修改網路結構

1.3 Sequential

在一些自己做的小專案中,Sequential其實用的更為頻繁。
依然重寫最初最簡單的例子:

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3,64,3),
            nn.Conv2d(64,64,3)
        )

    def forward(self,x):
        x = self.conv(x)
        return x
net = MyNet()
print(net)

執行結果:

觀察細緻的朋友可以發現這個問題,Seqential內的網路層是預設用數字進行標號的,而一開始我們使用self.conv1self.conv2的時候,使用conv1和conv2作為標號的。

我們如何修改Sequential中網路層的名稱呢?這裡需要使用到collections.OrderedDict有序字典。Sequential是支援有序字典構建的。

from collections import OrderedDict 
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.conv = nn.Sequential(OrderedDict([
            ('conv1',nn.Conv2d(3,64,3)),
            ('conv2',nn.Conv2d(64,64,3))
        ]))

    def forward(self,x):
        x = self.conv(x)
        return x
net = MyNet()
print(net)

輸出結果:

1.4 小總結

  • 單獨增加一個網路層或者子模組,可以用add_module或者直接賦予屬性;
  • ModuleList可以將一個Module的List增加到網路中,自由度較高。
  • Sequential按照順序產生一個Module模組。這裡推薦習慣使用OrderedDict的方法進行構建。對網路層加上規範的名稱,這樣有助於後續查詢與遍歷

2 遍歷模型結構

本章節使用下面的方法進行遍歷之前提到的Module。(個人理解,Module是多個layer的合併,但是一個layer可以說成Module。
先定義一個網路吧,隨便寫一個:

import torch.nn as nn
import torch 
from collections import OrderedDict
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3)
        self.conv2 = nn.Conv2d(64,64,3)
        self.maxpool1 = nn.MaxPool2d(2,2)

        self.features = nn.Sequential(OrderedDict([
            ('conv3', nn.Conv2d(64,128,3)),
            ('conv4', nn.Conv2d(128,128,3)),
            ('relu1', nn.ReLU())
        ]))

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool1(x)
        x = self.features(x)

        return x
net = MyNet()
print(net)

輸出結果是:

2.1 modules()

在第四課中初始化模型各個層的引數的時候,用到了這個方法,現在我們再來理解一下:

for idx,m in enumerate(net.modules()):
    print(idx,"-",m)

執行結果:

上面那個網路構建的時候用到了Sequential,所以網路中其實是巢狀了一個小的Module,這就是之前提到的樹狀結構,然後上面便利的時候也是樹狀結構的便利過程,可以看出來應該是一個深度遍歷的過程。

  • 首先第一個輸出的是最大的那個Module,也就是整個網路,0-Model整個網路模組;
  • 1-2-3-4是網路的四個子模組,4-Sequential中間仍然包含子模組
  • 5-6-7是模組4-Sequential的子模組。

【總結】

modules()是遞迴的返回網路的各個module(深度遍歷),從最頂層直到最後的葉子的module。

2.2 named_modules()

named_modules()module()類似,只是同時返回name和module。

for idx,(name,m) in enumerate(net.named_modules()):
    print(idx,"-",name)

輸出結果:

2.3 parameters()

for p in net.parameters():
    print(type(p.data),p.size())

執行結果:

輸出的是四個卷積層的權重矩陣引數和偏置引數。值得一提的是,對網路進行訓練時需要將parameters()作為優化器optimizer的引數。

optimizer = torch.optim.SGD(net.parameters(),
                            lr = 0.001,
                            momentum=0.9)

總之呢,這個parameters()是返回網路所有的引數,主要用在給optimizer優化器用的。而要對網路的某一層的引數做處理的時候,一般還是使用named_parameters()方便一些。

for idx,(name,m) in enumerate(net.named_parameters()):
    print(idx,"-",name,m.size())

輸出結果:

【小擴充套件】

我個人有時會使用下面的方法來獲取引數:

for idx,(name,m) in enumerate(net.named_modules()):
    if isinstance(m,nn.Conv2d):
        print(m.weight.shape)
        print(m.bias.shape)

先判斷是否是卷積層,然後獲取其引數,輸出結果:

3 儲存與載入

PyTorch使用torch.savetorch.load方法來儲存和載入網路,而且網路結構和引數可以分開的儲存和載入。

torch.save(model,'model.pth') # 儲存
model = torch.load("model.pth") # 載入

pytorch中網路結構和模型引數是可以分開儲存的。上面的方法是兩者同時儲存到了.pth檔案中,當然,你也可以僅僅儲存網路的引數來減小儲存檔案的大小。注意:如果你僅僅儲存模型引數,那麼在載入的時候,是需要通過執行程式碼來初始化模型的結構的。

torch.save(model.state_dict(),"model.pth") # 儲存引數
model = MyNet() # 程式碼中建立網路結構
params = torch.load("model.pth") # 載入引數
model.load_state_dict(params) # 應用到網路結構中

至此,我們今天已經學習了不少的內容,大家對PyTorch的掌握更近一步了呢~

相關文章