【小白學PyTorch】4 構建模型三要素與權重初始化

忽逢桃林發表於2020-09-03

文章目錄:

1 模型三要素

三要素其實很簡單

  1. 必須要繼承nn.Module這個類,要讓PyTorch知道這個類是一個Module
  2. 在__init__(self)中設定好需要的元件,比如conv,pooling,Linear,BatchNorm等等
  3. 最後在forward(self,x)中用定義好的元件進行組裝,就像搭積木,把網路結構搭建出來,這樣一個模型就定義好了

我們來看一個例子:
先看__init__(self)函式

def __init__(self):
	super(Net,self).__init__()
	self.conv1 = nn.Conv2d(3,6,5)
	self.pool1 = nn.MaxPool2d(2,2)
	self.conv2 = nn.Conv2d(6,16,5)
	self.pool2 = nn.MaxPool2d(2,2)
	self.fc1 = nn.Linear(16*5*5,120)
	self.fc2 = nn.Linear(120,84)
	self.fc3 = nn.Linear(84,10)

第一行是初始化,往後定義了一系列元件。nn.Conv2d就是一般圖片處理的卷積模組,然後池化層,全連線層等等。

定義完這些定義forward函式

def forward(self,x):
	x = self.pool1(F.relu(self.conv1(x)))
	x = self.pool2(F.relu(self.conv2(x)))
	x = x.view(-1,16*5*5)
	x = F.relu(self.fc1(x))
	x = F.relu(self.fc2(x))
	x = self.fc3(x)
	return x

x為模型的輸入,第一行表示x經過conv1,然後經過啟用函式relu,然後經過pool1操作
第三行表示對x進行reshape,為後面的全連線層做準備

至此,對一個模型的定義完畢,如何使用呢?

例如:

net = Net()
outputs = net(inputs)

其實net(inputs),就是類似於使用了net.forward(inputs)這個函式。

2 引數初始化

簡單地說就是設定什麼層用什麼初始方法,初始化的方法會在torch.nn.init中

話不多說,看一個案例:

# 定義權值初始化
def initialize_weights(self):
	for m in self.modules():
		if isinstance(m,nn.Conv2d):
			torch.nn.init.xavier_normal_(m.weight.data)
			if m.bias is not None:
				m.bias.data.zero_()
		elif isinstance(m,nn.BatchNorm2d):
			m.weight.data.fill_(1)
			m.bias.data.zero_()
		elif isinstance(m,nn.Linear):
			torch.nn.init.normal_(m.weight.data,0,0.01)
			# m.weight.data.normal_(0,0.01)
			m.bias.data.zero_()

這段程式碼的基本流程就是,先從self.modules()中遍歷每一層,然後判斷更曾屬於什麼型別,是否是Conv2d,是否是BatchNorm2d,是否是Linear的,然後根據不同型別的層,設定不同的權值初始化方法,例如Xavier,kaiming,normal_等等。kaiming也是MSRA初始化,是何愷明大佬在微軟亞洲研究院的時候,因此得名。

上面程式碼中用到了self.modules(),這個是什麼東西呢?

# self.modules的原始碼
def modules(self):
	for name,module in self.named_modules():
		yield module

功能就是:能依次返回模型中的各層,yield是讓一個函式可以像迭代器一樣可以用for迴圈不斷從裡面遍歷(可能說的不太明確)。

3 完整執行程式碼

我們用下面的例子來更深入的理解self.modules(),同時也把上面的內容都串起來(下面的程式碼塊可以執行):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                # m.weight.data.normal_(0,0.01)
                m.bias.data.zero_()

net = Net()
net.initialize_weights()
print(net.modules())
for m in net.modules():
    print(m)

執行結果:

# 這個是print(net.modules())的輸出
<generator object Module.modules at 0x0000023BDCA23258>
# 這個是第一次從net.modules()取出來的東西,是整個網路的結構
Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
# 從net.modules()第二次開始取得東西就是每一層了
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Linear(in_features=400, out_features=120, bias=True)
Linear(in_features=120, out_features=84, bias=True)
Linear(in_features=84, out_features=10, bias=True)

其中呢,並不是每一層都有偏執bias的,有的卷積層可以設定成不要bias的,所以對於卷積網路引數的初始化,需要判斷一下是否有bias,(不過我好像記得bias預設初始化為0?不確定,有知道的朋友可以交流)

torch.nn.init.xavier_normal(m.weight.data)
if m.bias is not None:
	m.bias.data.zero_()

上面程式碼表示用xavier_normal方法對該層的weight初始化,並判斷是否存在偏執bias,若存在,將bias初始化為0。

4 尺寸計算與引數計算

我們把上面的主函式部分改成:

net = Net()
net.initialize_weights()
layers = {}
for m in net.modules():
    if isinstance(m,nn.Conv2d):
        print(m)
        break

這裡的輸出m就是:

Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))

這個卷積層,就是我們設定的第一個卷積層,含義就是:輸入3通道,輸出6通道,卷積核\(5\times 5\),步長1,padding=0.

【問題1:輸入特徵圖和輸出特徵圖的尺寸計算】

之前的文章也講過這個了,

\(output = \frac{input+2\times padding -kernel}{stride}+1\)

用程式碼來驗證一下這個公式:

net = Net()
net.initialize_weights()
input = torch.ones((16,3,10,10))
output = net.conv1(input)
print(input.shape)
print(output.shape)

初始結果:

torch.Size([16, 3, 10, 10])
torch.Size([16, 6, 6, 6])

第一個維度上batch,第二個是通道channel,第三個和第四個是圖片(特徵圖)的尺寸。

\(\frac{10+2\times 0-5}{1}+1=6\) 算出來的結果沒毛病。

【問題2:這個卷積層中有多少的引數?】
輸入通道是3通道的,輸出是6通道的,卷積核是\(5\times 5\)的,所以理解為6個\(3\times 5\times 5\)的卷積核,所以不考慮bias的話,引數量是\(3\times 5\times 5\times 6=450\),考慮bais的話,就每一個卷積核再增加一個偏置值。(這是一個一般人會忽略的知識點欸)

下面用程式碼來驗證:

net = Net()
net.initialize_weights()
for m in net.modules():
    if isinstance(m,nn.Conv2d):
        print(m)
        print(m.weight.shape)
        print(m.bias.shape)
        break

輸出結果是:

Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
torch.Size([6, 3, 5, 5])
torch.Size([6])

都和預料中一樣。

相關文章