ResNet詳解:網路結構解讀與PyTorch實現教程

techlead_krischang發表於2023-10-29

本文深入探討了深度殘差網路(ResNet)的核心概念和架構組成。我們從深度學習和梯度消失問題入手,逐一解析了殘差塊、初始卷積層、殘差塊組、全域性平均池化和全連線層的作用和優點。文章還包含使用PyTorch構建和訓練ResNet模型的實戰部分,帶有詳細的程式碼和解釋。

關注TechLead,分享AI與雲服務技術的全維度知識。作者擁有10+年網際網路服務架構、AI產品研發經驗、團隊管理經驗,同濟本復旦碩,復旦機器人智慧實驗室成員,阿里雲認證的資深架構師,專案管理專業人士,上億營收AI產品研發負責人。

file

一、深度殘差網路(Deep Residual Networks)簡介

深度殘差網路(Deep Residual Networks,簡稱ResNet)自從2015年首次提出以來,就在深度學習領域產生了深遠影響。透過一種創新的“殘差學習”機制,ResNet成功地訓練了比以往模型更深的神經網路,從而顯著提高了多個任務的效能。深度殘差網路透過引入殘差學習和特殊的網路結構,解決了傳統深度神經網路中的梯度消失問題,並實現了高效、可擴充套件的深層模型。

深度學習與網路深度的挑戰

在深度學習中,網路的“深度”(即層數)通常與模型的能力成正比。然而,隨著網路深度的增加,一些問題也隨之出現,最突出的是梯度消失/爆炸問題。這使得深層網路難以訓練。

殘差學習的提出

file
傳統的深度神經網路試圖學習目標函式 ( H(x) ),但是在ResNet中,每個網路層實際上學習的是一個殘差函式 ( F(x) = H(x) - x )。然後,這個殘差結果與輸入 ( x ) 相加,形成 ( H(x) = F(x) + x )。這一機制使得網路更容易學習身份對映,進而緩解了梯度消失問題。

# PyTorch中的殘差塊實現
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
            
    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out

# 輸出示例
x = torch.randn(64, 3, 32, 32)
block = ResidualBlock(3, 64)
out = block(x)
print(out.shape)  # 輸出:torch.Size([64, 64, 32, 32])

為什麼ResNet有效?

  • 解決梯度消失問題:透過殘差連線,梯度能夠更容易地反向傳播。
  • 引數效率:與傳統的深層網路相比,ResNet能以更少的引數實現更好的效能。

二、深度學習與梯度消失問題

在深入研究深度殘差網路(ResNet)之前,理解梯度消失問題是至關重要的。該問題長期以來一直是訓練深層神經網路的主要難點。本節將講解梯度消失問題的基本原理,以及這一問題如何影響深度學習模型的訓練。

梯度消失問題定義

梯度消失問題發生在神經網路的反向傳播過程中,具體表現為網路中某些權重的梯度接近或變為零。這導致這些權重幾乎不會更新,從而阻礙了網路的訓練。

數學上,假設我們有一個誤差函式 ( E ),對於網路中的某個權重 ( w ),如果 ( \frac{\partial E}{\partial w} ) 趨近於零,則表明出現了梯度消失問題。

為什麼會出現梯度消失?

啟用函式

使用Sigmoid或者Tanh等飽和啟用函式時,其導數在兩端極小,這很容易導致梯度消失。

初始化方法

權重初始化不當也可能導致梯度消失。例如,如果初始化權重過小,那麼啟用函式的輸出和梯度都可能非常小。

網路深度

網路越深,梯度在反向傳播過程中經過的層就越多,導致梯度消失問題更加嚴重。

如何解決梯度消失問題

  • 使用ReLU啟用函式:ReLU(Rectified Linear Unit)啟用函式能夠緩解梯度消失。
  • 合適的權重初始化:如He初始化或Glorot初始化。
  • 使用短接結構(Skip Connections):這是ResNet解決梯度消失問題的核心機制。
# 使用ReLU和He初始化的簡單示例
import torch.nn as nn

class SimpleNetwork(nn.Module):
    def __init__(self):
        super(SimpleNetwork, self).__init__()
        self.layer1 = nn.Linear(10, 50)
        nn.init.kaiming_normal_(self.layer1.weight, nonlinearity='relu')  # He初始化
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        return x

# 輸出示例
x = torch.randn(32, 10)
model = SimpleNetwork()
out = model(x)
print(out.shape)  # 輸出:torch.Size([32, 50])

三、殘差塊(Residual Blocks)基礎

殘差塊(Residual Blocks)是深度殘差網路(Deep Residual Networks,或ResNet)中的基本構建單元。透過使用殘差塊,ResNet有效地解決了梯度消失問題,並能訓練極深的網路。本節將深入探討殘差塊的基礎概念、設計與實現。殘差塊作為ResNet的基礎組成部分,其設計充分考慮了訓練穩定性和模型效能。透過引入殘差學習和短接連線,ResNet能夠有效地訓練深度網路,從而在多個任務上達到先進的效能。
file

殘差塊的核心思想

在傳統的卷積神經網路(CNN)中,每個卷積層試圖學習輸入與輸出之間的對映。殘差塊則採用了不同的策略:它們試圖學習輸入與輸出之間的殘差對映,即:

[
F(x) = H(x) - x
]

其中,( F(x) ) 是殘差函式,( H(x) ) 是目標對映函式,( x ) 是輸入。然後,( F(x) ) 與輸入 ( x ) 相加,得到最終輸出:

[
H(x) = F(x) + x
]

結構組成

一個基礎的殘差塊通常包含以下幾個部分:

  • 卷積層:用於特徵提取。
  • 批次歸一化(Batch Normalization):用於加速訓練和改善模型泛化。
  • 啟用函式:通常使用ReLU。
  • 短接連線(Skip Connection):直接連線輸入和輸出。
# 殘差塊的PyTorch實現
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
            
    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out

# 輸出示例
x = torch.randn(64, 3, 32, 32)
block = ResidualBlock(3, 64)
out = block(x)
print(out.shape)  # 輸出:torch.Size([64, 64, 32, 32])

殘差塊的變體

  • Bottleneck Blocks:在更深的ResNet(如ResNet-152)中,為了減少計算量,通常使用“瓶頸”結構,即先透過一個小的卷積核(如1x1)降維,再進行3x3卷積,最後透過1x1卷積恢復維度。

四、ResNet架構

file

本節將介紹ResNet(深度殘差網路)的整體架構,以及它在計算機視覺和其他領域的應用。一個標準的ResNet模型由多個殘差塊組成,通常開始於一個普通的卷積層和池化層,用於進行初步的特徵提取。接下來是一系列的殘差塊,最後是全域性平均池化層和全連線層。

架構組成

  • 初始卷積層:用於初步特徵提取。
  • 殘差塊組(Residual Blocks Group):包含多個殘差塊。
  • 全域性平均池化(Global Average Pooling):減小維度。
  • 全連線層:用於分類或其他任務。

4.1 初始卷積層

file
在進入深度殘差網路的主體結構之前,第一層通常是一個初始卷積層。這個卷積層的主要任務是對輸入影像進行一定程度的空間下采樣(Spatial Downsampling)和特徵抽取。

功能和作用

  1. 空間下采樣(Spatial Downsampling): 初始卷積層通常具有較大的卷積核和步長(stride),用於減少後續層需要處理的空間維度,從而降低計算複雜度。
  2. 特徵抽取: 初始卷積層能夠抓取影像的基礎特徵,如邊緣、紋理等,為後續的特徵抽取工作打下基礎。

結構詳解

在ResNet-18和ResNet-34中,這一初始卷積層通常由一個7x7大小的卷積核、步長(stride)為2和填充(padding)為3組成。這個層後面通常還會跟隨一個批次歸一化(Batch Normalization)層和ReLU啟用函式。

self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)

為何不使用多個小卷積核?

在一些其他網路架構中,初始層可能由多個小卷積核(如3x3)組成,那麼為什麼ResNet要使用一個較大的7x7卷積核呢?主要原因是,一個大的卷積核可以在相同數量的引數下,提供更大的感受野(Receptive Field),從而更有效地捕獲影像的全域性資訊。

小結

初始卷積層在整個ResNet架構中扮演著非常重要的角色。它不僅完成了對輸入影像的基礎特徵抽取,還透過空間下采樣減輕了後續計算的負擔。這些設計細節共同使得ResNet能在保持高效能的同時,具有更低的計算複雜度。

4.2 殘差塊組(Residual Block Groups)

file
在初始卷積層之後,緊接著就是ResNet的核心組成部分,也就是殘差塊組(Residual Block Groups)。這些殘差塊組成了ResNet架構中的主體,負責高階特徵的抽取和傳遞。

功能和作用

  1. 特徵抽取: 每個殘差塊組負責從其前一組中提取的特徵中提取更高階的特徵。
  2. 非線性效能增強: 透過殘差連結,每個殘差塊組能夠學習輸入與輸出之間的複雜非線性對映。
  3. 避免梯度消失和爆炸: 殘差塊組內的Skip Connection(跳過連線)能夠更好地傳遞梯度,有助於訓練更深的網路。

結構詳解

在標準的ResNet-18或ResNet-34模型中,通常會包括幾組殘差塊。每一組都有一定數量的殘差塊,這些塊的數量和組的深度有關。

  • 第一組可能包括2個殘差塊,用64個輸出通道。
  • 第二組可能包括2個殘差塊,用128個輸出通道。
  • 第三組可能包括2個殘差塊,用256個輸出通道。
  • 第四組可能包括2個殘差塊,用512個輸出通道。
# 示例程式碼,表示第一組殘差塊
self.layer1 = nn.Sequential(
    ResidualBlock(64, 64),
    ResidualBlock(64, 64)
)

殘差塊組與特徵圖大小

每一組的第一個殘差塊通常會減小特徵圖的尺寸(即進行下采樣),而增加輸出通道數。這樣做可以保證模型的計算效率,同時能抓住更多層次的特徵。

小結

殘差塊組是ResNet架構中最核心的部分,透過逐層抽取更高階的特徵並透過殘差連線最佳化梯度流動,這些設計使得ResNet模型能夠有效並且準確地進行影像分類以及其他計算機視覺任務。

4.3 全域性平均池化(Global Average Pooling)

file
在透過一系列殘差塊組進行特徵抽取和非線性對映之後,ResNet通常使用全域性平均池化層(Global Average Pooling,簡稱GAP)作為網路的最後一個卷積層。與傳統的全連線層相比,全域性平均池化有幾個顯著優點。

功能和作用

  1. 降維: 全域性平均池化層將每個特徵圖(Feature Map)縮減為一個單一的數值,從而顯著減小模型引數和計算量。
  2. 防止過擬合: 由於其簡單性和少量的引數,全域性平均池化有助於防止模型過擬合。
  3. 改善泛化能力: 簡化的網路結構能更好地泛化到未見過的資料。

結構詳解

全域性平均池化層簡單地計算每個特徵圖的平均值。假設我們有一個形狀為(batch_size, num_channels, height, width)的特徵圖,全域性平均池化將輸出一個形狀為(batch_size, num_channels)的張量。

# PyTorch中的全域性平均池化
self.global_avg_pooling = nn.AdaptiveAvgPool2d((1, 1))

與全連線層的比較

在許多傳統的卷積神經網路(如AlexNet)中,網路的末端通常包括幾個全連線層。然而,全連線層往往包含大量的引數,從而增加了過擬合的風險。與之相比,全域性平均池化由於其引數更少、計算更簡單,因此更受現代深度學習架構的青睞。

小結

全域性平均池化是ResNet架構的一個重要組成部分,它不僅顯著減小了模型的引數數量,還有助於提高模型的泛化能力。這些優點使得全域性平均池化在許多現代卷積神經網路中都有廣泛的應用。

4.4 全連線層(Fully Connected Layer)

file
在全域性平均池化(GAP)之後,ResNet架構通常包含一個或多個全連線層(Fully Connected Layer)。全連線層在ResNet中的主要目的是為了進行分類或者回歸任務。

功能和作用

  1. 分類或迴歸: 全連線層的主要任務是根據前層特徵進行分類或迴歸。
  2. 增加模型複雜度: 相比GAP,全連線層可以增加模型的複雜度,從而擬合更復雜的函式。
  3. 特徵整合: 全連線層能夠整合前面各層的資訊,輸出一個固定大小的特徵向量。

結構詳解

全連線層通常接收全域性平均池化層輸出的平坦化(flattened)向量,並透過一系列線性變換與啟用函式生成輸出。例如,在分類問題中,全連線層通常輸出一個與類別數相等的節點。

# PyTorch中的全連線層示例
self.fc = nn.Linear(512, num_classes)  # 假設全域性平均池化後有512個通道,num_classes為分類數量

啟用函式與Dropout

全連線層之後通常會接一個啟用函式,如ReLU或者Softmax,以引入非線性。有時也會使用Dropout層來防止過擬合,尤其是在全連線層的節點數較多時。

小結

雖然全連線層相對簡單,但它在ResNet以及其他深度學習模型中佔據重要地位。全連線層是進行分類或迴歸的關鍵,同時也為模型提供了最後的機會進行特徵整合和學習複雜對映。


五、實戰:使用PyTorch構建ResNet模型

5.1 構建ResNet模型

在這一部分中,我們將使用PyTorch框架來實現一個簡化版的ResNet-18模型。我們的目標是構建一個可以在CIFAR-10資料集上進行分類任務的模型。

前置條件

確保您已經安裝了PyTorch和其他必要的庫。

pip install torch torchvision

構建Residual Block

首先,讓我們實現一個殘差塊。這是前面章節已經介紹過的內容。

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

構建ResNet-18

接下來,我們使用殘差塊來構建完整的ResNet-18模型。

class ResNet18(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(64, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, stride))
        for _ in range(1, blocks):
            layers.append(ResidualBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

以上程式碼定義了一個用於CIFAR-10分類任務的ResNet-18模型。在這個模型中,我們使用了前面定義的ResidualBlock類,並透過_make_layer函式來堆疊多個殘差塊。

模型測試

接下來,我們可以測試這個模型以確保其結構是正確的。

# 建立一個模擬輸入
x = torch.randn(64, 3, 32, 32)

# 例項化模型
model = ResNet18(num_classes=10)

# 前向傳播
output = model(x)

# 輸出形狀應為(64, 10),因為我們有64個樣本和10個類別
print(output.shape)  # 輸出:torch.Size([64, 10])

5.2 訓練與評估

在成功構建了ResNet-18模型之後,下一步就是進行模型的訓練和評估。在這一部分,我們將介紹如何在CIFAR-10資料集上完成這兩個步驟。

資料預處理與載入

首先,我們需要準備資料。使用PyTorch的torchvision庫,我們可以非常方便地下載和預處理CIFAR-10資料集。

import torch
import torchvision
import torchvision.transforms as transforms

# 資料預處理
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 載入資料集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)

模型訓練

訓練模型通常需要指定損失函式和最佳化器,並反覆進行前向傳播、計算損失、反向傳播和引數更新。

import torch.optim as optim

# 例項化模型並移至GPU
model = ResNet18(num_classes=10).cuda()

# 定義損失函式和最佳化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# 訓練模型
for epoch in range(10):  # 執行10個週期
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.cuda(), labels.cuda()

        # 清零梯度快取
        optimizer.zero_grad()

        # 前向傳播,計算損失,反向傳播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # 更新引數
        optimizer.step()

模型評估

訓練完成後,我們需要評估模型的效能。這通常透過在測試集上計算模型的準確率來完成。

# 切換模型為評估模式
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.cuda(), labels.cuda()
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

六、總結

透過深入探討ResNet的關鍵組成部分,包括深度殘差網路、梯度消失問題、殘差塊、初始卷積層、殘差塊組、全域性平均池化以及全連線層,我們不僅理解了其背後的設計思想和優勢,還透過PyTorch實現了一個完整的ResNet模型並進行了訓練與評估。ResNet透過其獨特的殘差連線有效地解決了深度網路中的梯度消失問題,並且在多項視覺任務中實現了突破性的效能。這些優點使得ResNet成為現代深度學習架構中不可或缺的一部分。

關注TechLead,分享AI與雲服務技術的全維度知識。作者擁有10+年網際網路服務架構、AI產品研發經驗、團隊管理經驗,同濟本復旦碩,復旦機器人智慧實驗室成員,阿里雲認證的資深架構師,專案管理專業人士,上億營收AI產品研發負責人。
如有幫助,請多關注
TeahLead KrisChang,10+年的網際網路和人工智慧從業經驗,10年+技術和業務團隊管理經驗,同濟軟體工程本科,復旦工程管理碩士,阿里雲認證雲服務資深架構師,上億營收AI產品業務負責人。

相關文章