LeNet-5網路搭建詳解

追風趕月的少年發表於2021-06-20

LeNet-5是由Yann LeCun設計的用於手寫數字識別和機器列印字元的卷積神經網路。她在1998年發表的論文《基於梯度學習的文字識別》中提出了該模型,並給出了對該模型網路架構的介紹。如下圖所示,LeNet-5共有7層(不包括輸入層),包含卷積層、下采樣層、全連線層,而其輸入影像為32*32.論文連結:Gradient-based learning applied to document recognition | IEEE Journals & Magazine | IEEE Xplore

 圖1. LeNet-5網路架構

1.C1:卷積層

c1層採用卷積層對輸入的影像進行特徵提取,利用6個5*5的卷積核生成6個特徵圖(feature map)。其步長為1且不使用擴充值。因此卷積後的特徵層為28*28.一個卷積核擁有的可訓練引數為5*5+1=26,其中1為偏置引數。整個C1層可訓練引數為(5*5+1)*6=156.

2.S2:下采樣層

下采樣(subsampling)層主要對特徵進行降維處理,效果與池化相同。S2層使用2*2的濾波器池化C1的特徵圖,因此將生成6個尺寸為14*14的特徵圖。在計算時,將濾波器中的4個值相加,然後乘以可訓練權值引數w,加上偏置引數b,最後通過sigmoid函式形成新的值。S2層的每個特徵圖中都有兩個引數,一個是權值引數,一個是偏置引數,因此該層共有2*6=12個引數。

3.C3:卷積層

C3層有16個大小為5*5的卷積核,步長為1且不填充邊界。C3層將S2層6個14*14的特徵圖卷積成16個10*10的特徵圖。值得注意的是,S2層與C3層的卷積核並不是全連線的,而是部分連線的。

 

 圖2:S2層特徵圖與C3層卷積核連線的組合

4.S4:下采樣層

S4的濾波器與S2層的濾波器相似,也是2*2的,所以,S4層的特徵圖池化後,將生成16個5*5的特徵圖。S4層引數的個數為2*16=32.

5.C5:卷積層

C5層有120個5*5的卷積核,將產生120個1*1的特徵圖,與S4層是全連線的。C5層引數的個數不能參照C1層來計算,而是要參照C3層來計算,且此時是沒有組合的,因此,應該是(5*5*16+1)*120=48120.

6.F6:全連線層

F6有84個單元,單元的個數與輸出層的設計有關。該層作為典型的神經網路層,每一個單元都計算輸入向量與權值引數的點積並加上偏置引數,然後傳給sigmoid函式,產生該單元的一個狀態並傳遞給輸出層。在這裡,將輸出作為輸出層的徑向基函式的初始引數,用於識別完整的ASCII字符集。C5有120個單元;F6層有84個單元,每個單元都將容納120個單元的計算結果。因此,F6層引數的個數為(120+1)*84=10164.

7.output:輸出層

 

output層是全連線層,共有10個單元,代表數字0~9。利用徑向基函式,將F6層84個單元的輸出作為節點的輸入xj,計算歐氏距離。距離越近,結果就越小,意味著識別的樣本越符合該節點所代表的字元。由於該層是全連線層,引數個數為84*10=840。

 

 網路搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        #input image channel is one, output channels is six,5*5 square convolution
        self.conv1=nn.Conv2d(1, 6, 5)
        self.conv2=nn.Conv2d(6, 16, 5)
        self.fc1=nn.Linear(16*5*5, 120)
        self.fc2=nn.Linear(120, 84)
        self.fc3=nn.Linear(84, 10)
    
    def forward(self, x):
        #max pooling over a (2,2) window
        #c1
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        #if the kernel size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def num_flat_features(self, x):
        #all dimensions except the batch dimension
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = Net()
print(net)

 

相關文章