ResNet(Residual Network,殘差網路)是深度學習領域中的重要突破之一,由 Kaiming He 等人在 2015 年提出。其核心思想是透過引入殘差連線(skip connections)來緩解深層網路中的梯度消失問題,使得網路可以更高效地訓練,同時顯著提升了深度網路的效能。
本文以一個 ResNet 的簡單實現為例,詳細解析其工作原理、程式碼結構和設計思想,並介紹 ResNet 的發展背景和改進版本。
背景與動機
隨著網路深度的增加,傳統深層神經網路面臨以下問題:
- 梯度消失與梯度爆炸: 在網路傳播過程中,梯度逐層衰減或爆炸,使得深層網路難以有效訓練。
- 退化問題: 增加網路深度並不一定帶來更高的準確率,反而可能導致訓練誤差增大。
為了應對這些挑戰,ResNet 提出了殘差學習框架,透過學習輸入與輸出之間的殘差來簡化最佳化過程。
殘差塊 (Residual Block)
設計思想
在 ResNet 中,一個基本的單元是殘差塊。假設希望擬合一個目標對映H(x),ResNet 將其重新表述為:
\[H(x) = F(x) + x
\]
其中:
- F(x) 是要學習的殘差函式。
- x 是輸入,直接透過快捷連線(shortcut connection)傳遞到輸出。
這種設計可以讓網路更容易最佳化,因為相比直接學習 H(x),學習 F(x)通常更容易。
程式碼實現
以下是一個標準的殘差塊實現:
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
# 當輸入和輸出維度不匹配時,新增一個卷積層以調整維度
if in_channels != out_channels or stride != 1:
self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
self.downsample_bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
residual = self.downsample_bn(residual)
out += residual
out = self.relu(out)
return out
核心部分解析:
- 卷積操作:
- 使用兩個3 \(\times\) 3的卷積核,提取特徵。
- 透過批歸一化 (BatchNorm) 穩定訓練。
- 殘差連線:
- 當輸入和輸出通道數一致時,直接加和。
- 若通道數或尺寸不同,則透過1 \(\times\) 1卷積調整形狀。
- 啟用函式:
- 使用 ReLU 函式,增加非線性。
ResNet 網路結構
ResNet 由多個殘差塊堆疊而成,不同版本的 ResNet 使用的塊數和通道數不同。以下是一個簡化的 ResNet 實現:
class ResNet(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
# 輸入影像尺寸為 28 x 28
self.block1 = ResBlock(3, 64)
# 輸出 28 x 28
self.block2 = ResBlock(64, 128, stride=2)
# 輸出 14 x 14
self.block3 = ResBlock(128, 256, stride=2)
# 輸出 7 x 7
self.block4 = ResBlock(256, 512, stride=2)
# 輸出 4 x 4
self.block5 = ResBlock(512, 1024, stride=2)
# 輸出 2 x 2
self.block6 = ResBlock(1024, 2048, stride=2)
# 輸出 1 x 1
self.fc = nn.Linear(2048, num_classes)
def forward(self, x):
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = self.block5(x)
x = self.block6(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
網路結構說明:
- 輸入為 28 \(\times\) 28 的影像,透過 6 個殘差塊提取特徵。
- 每次透過殘差塊,通道數增加,空間尺寸減少一半。
- 最後透過全連線層實現分類。
ResNet 的優勢
- 解決梯度問題: 殘差連線使得梯度能夠直接傳遞到前層,有效緩解了梯度消失問題。
- 更深的網路: ResNet-50 和 ResNet-152 等深度版本大大提升了效能,廣泛用於影像分類、目標檢測等任務。
- 模組化設計: 殘差塊設計簡單,可擴充套件性強。
總結
本文透過程式碼實現和理論講解,深入解析了 ResNet 的核心思想和設計細節。ResNet 是深度學習領域的重要里程碑,其提出的殘差學習框架為訓練深層網路提供了有效的方法。隨著 ResNet 的不斷髮展,它在各種任務中依然表現強勁,是經典的深度學習模型之一。
透過理解 ResNet 的原理和實現,我們不僅可以靈活應用現有的網路架構,還能為創新和改進深度網路提供思路。