深入解析 PyTorch 的 BatchNorm2d:原理與實現

crazypigf發表於2024-11-20

在深度學習中,Batch Normalization 是一種常用的技術,用於加速網路訓練並穩定模型收斂。本文將結合一個具體程式碼例項,詳細解析 PyTorch 中 BatchNorm2d 的實現原理,同時透過手動計算驗證其計算過程,幫助大家更直觀地理解 BatchNorm 的工作機制。


1. Batch Normalization 的基本原理

1.1 什麼是 Batch Normalization?

Batch Normalization (BN) 是由 Sergey Ioffe 和 Christian Szegedy 在 2015 年提出的一種正則化方法。其主要目的是解決深度神經網路訓練中因輸入資料分佈不一致(即 內部協變數偏移)而導致的訓練困難問題。BN 的核心思想是對每一批資料的特徵進行標準化,具體包括以下步驟:

  1. 計算每個特徵的均值與方差: 對輸入特徵 x計算均值 \(\mu\) 和方差 \(\sigma^2\)

    \[\mu = \frac{1}{N} \sum_{i=1}^N x_i, \quad \sigma^2 = \frac{1}{N} \sum_{i=1}^N (x_i - \mu)^2 \]

    其中 N 是當前批次的樣本總數。

  2. 對特徵進行標準化:

    \[\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \]

    這裡,\(\epsilon\) 是一個很小的數,用於防止分母為零。

  3. 引入可學習的仿射變換:

    \[y = \gamma \hat{x} + \beta \]

    其中,\(\gamma\)\(\beta\) 是可學習的引數,用於恢復模型的表達能力。

1.2 為什麼使用 BatchNorm?

  • 加速收斂: 透過標準化減少內部協變數偏移,使得啟用值在訓練過程中分佈更加穩定,從而加快收斂速度。
  • 正則化效果: BN 在一定程度上起到正則化作用,可以減少對 Dropout 等正則化技術的依賴。
  • 更高的學習率: 由於 BN 能緩解梯度爆炸或消失的問題,允許使用更高的學習率。

2. PyTorch 中 BatchNorm2d 的實現

在 PyTorch 中,BatchNorm2d 是專為 4D 輸入(即二維卷積層的輸出)設計的批歸一化操作。其計算流程如下:

  1. 輸入維度: 假設輸入的維度為 (N, C, H, W),其中:
    • N 是 batch size;
    • C 是通道數;
    • H,W 是特徵圖的高度和寬度。
  2. 統計均值和方差: 對每個通道 C 分別計算均值和方差,統計維度為 [0, 2, 3](即對 batch size 和空間維度進行平均)。
  3. 標準化和仿射變換: 按公式 \(y = \gamma \hat{x} + \beta\) 計算輸出。

3. 程式碼解析與實現

3.1 示例程式碼

以下是一個完整的程式碼示例:

import torch
import torch.nn as nn

# 設定隨機種子,保證結果可復現
torch.manual_seed(1107)

# 建立一個 4D 張量,形狀為 (2, 3, 4, 4)
x = torch.rand(2, 3, 4, 4)

# 例項化 BatchNorm2d,通道數為 3,momentum 設定為 1
m = nn.BatchNorm2d(3, momentum=1)
y = m(x)

# 手動計算 BatchNorm2d
x_mean = x.mean(dim=[0, 2, 3], keepdim=True)  # 按通道計算均值
x_var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)  # 按通道計算方差(無偏)
eps = m.eps  # 獲取 epsilon 值

y_manual = (x - x_mean) / ((x_var + eps).sqrt())  # 標準化公式

# 檢查兩種方法的輸出是否一致
print("使用 BatchNorm2d 的結果:", y)
print("手動計算的結果:", y_manual)
print("結果是否一致:", torch.allclose(y, y_manual, atol=1e-6))

3.2 輸出結果

執行上述程式碼,輸出如下:

使用 BatchNorm2d 的結果: tensor([[[[ 1.2311,  0.5357,  ...],
手動計算的結果: tensor([[[[ 1.2311,  0.5357,  ...],
結果是否一致: True

可以看到,BatchNorm2d 和手動計算的結果完全一致,這說明我們對其計算過程的推導是正確的。


4. 程式碼逐步解析

4.1 建立隨機輸入資料

x = torch.rand(2, 3, 4, 4)

這裡建立了一個形狀為 (2, 3, 4, 4) 的 4D 張量,模擬卷積層的輸出,其中:

  • Batch size N=2;
  • 通道數 C=3;
  • 特徵圖大小 H=4,W=4。

4.2 BatchNorm2d 的初始化

m = nn.BatchNorm2d(3, momentum=1)
  • 通道數: 3,對應輸入資料的通道數。
  • 動量: momentum=1 表示在每個批次中完全依賴當前批次的統計值,而不平滑更新。

4.3 手動計算均值與方差

x_mean = x.mean(dim=[0, 2, 3], keepdim=True)
x_var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)
  • dim=[0, 2, 3] 指定計算均值和方差的維度,即跨 batch 和空間維度。
  • keepdim=True 保留原始維度,便於後續廣播操作。
  • unbiased=False 關閉無偏估計,與 BatchNorm 的預設設定一致。

4.4 手動計算標準化結果

y_manual = (x - x_mean) / ((x_var + eps).sqrt())

這裡實現了標準化公式:

\(\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}\)


4.5 驗證結果一致性

torch.allclose(y, y_manual, atol=1e-6)

使用 torch.allclose 驗證兩者是否一致,允許的誤差範圍由 atol=1e-6 指定。


5. 總結與思考

透過上述分析與程式碼實現,我們可以更直觀地理解 PyTorch 中 BatchNorm2d 的工作原理。總結如下:

  1. BatchNorm 的核心操作是標準化與仿射變換。
  2. PyTorch 的實現細節非常最佳化,支援多維資料的高效處理。
  3. 手動實現 BatchNorm 可以幫助我們驗證模型行為,並在自定義層中實現類似功能。

思考: 在實際應用中,BatchNorm 的效果與 batch size 有很大關係,小 batch size 時可能導致統計量不穩定,建議結合 Group Normalization 等替代方法使用。

相關文章