神經網路基礎部件-BN層詳解

嵌入式視覺發表於2023-02-10

一,數學基礎

1.1,機率密度函式

隨機變數(random variable)是可以隨機地取不同值的變數。隨機變數可以是離散的或者連續的。簡單起見,本文用大寫字母 \(X\) 表示隨機變數,小寫字母 \(x\) 表示隨機變數能夠取到的值。例如,\(x_1\)\(x_2\) 都是隨機變數 \(X\) 可能的取值。隨機變數必須伴隨著一個機率分佈來指定每個狀態的可能性。

機率分佈(probability distribution)用來描述隨機變數或一簇隨機變數在每一個可能取到的狀態的可能性大小。我們描述機率分佈的方式取決於隨機變數是離散的還是連續的。

當我們研究的物件是連續型隨機變數時,我們用機率密度函式(probability density function, PDF)而不是機率質量函式來描述它的機率分佈。

更多內容請閱讀《花書》第三章-機率與資訊理論,或者我的文章-深度學習數學基礎-機率與資訊理論

1.2,正態分佈

當我們不知道資料真實分佈時使用正態分佈的原因之一是,正態分佈擁有最大的熵,我們透過這個假設來施加儘可能少的結構。

實數上最常用的分佈就是正態分佈(normal distribution),也稱為高斯分佈 (Gaussian distribution)。

如果隨機變數 \(X\) ,服從位置引數為 \(\mu\)、尺度引數為 \(\sigma\) 的機率分佈,且其機率密度函式為:

\[f(x)=\frac{1}{\sigma\sqrt{2 \pi} } e^{- \frac{{(x-\mu)^2}}{2\sigma^2}} \tag{1} \]

則這個隨機變數就稱為正態隨機變數,正態隨機變數服從的機率分佈就稱為正態分佈,記作:

\[X \sim N(\mu,\sigma^2) \tag{2} \]

如果位置引數 \(\mu = 0\),尺度引數 \(\sigma = 1\) 時,則稱為標準正態分佈,記作:

\[X \sim N(0, 1) \tag{3} \]

此時,機率密度函式公式簡化為:

\[f(x)=\frac{1}{\sqrt{2 \pi}} e^{- \frac{x^2}{2}} \tag{4} \]

正太分佈的數學期望值或期望值 \(\mu\) 等於位置引數,決定了分佈的位置;其方差 \(\sigma^2\) 的開平方或標準差 \(\sigma\) 等於尺度引數,決定了分佈的幅度。正太分佈的機率密度函式曲線呈鐘形,常稱之為鐘形曲線,如下圖所示:

正太分佈機率密度函式曲線

可認為構造正太分佈函式,也可透過 np.random.normal 函式生成指定均值和標準差的正態分佈隨機數,然後基於 matplotlib + seabornkdeplot函式繪製機率密度曲線。示例程式碼如下所示:

import seaborn as sns
x1 = np.random.normal(0, 1, 100)
x2 = np.random.normal(0, 1.5, 100) 
x3 = np.random.normal(2, 1.5, 100) 

plt.figure(dpi = 200)

sns.kdeplot(x1, label="μ=0, σ=1")
sns.kdeplot(x2, label="μ=0, σ=1.5")
sns.kdeplot(x3, label="μ=2, σ=2.5")

#顯示圖例
plt.legend()
#新增標題
plt.title("Normal distribution")
plt.show()

以上程式碼直接執行後,輸出結果如下圖:

不同引數的正態分佈函式曲線

當然也可以自己實現正太分佈的機率密度函式,程式碼和程式輸出結果如下:

import numpy as np
import matplotlib.pyplot as plt
plt.figure(dpi = 200)
plt.style.use('seaborn-darkgrid') # 主題設定

def nd_func(x, sigma, mu):
  	"""自定義實現正太分佈的機率密度函式
  	"""
    a = - (x-mu)**2 / (2*sigma*sigma)
    f = np.exp(a) / (sigma * np.sqrt(2*np.pi))
    return f

if __name__ == '__main__':
    x = np.linspace(-5, 5)
    f = nd_fun(x, 1, 0)
    p1, = plt.plot(x, f)

    f = nd_fun(x, 1.5, 0)
    p2, = plt.plot(x, f)

    f = nd_fun(x, 1.5, 2)
    p3, = plt.plot(x, f)

    plt.legend([p1 ,p2, p3], ["μ=0,σ=1", "μ=0,σ=1.5", "μ=2,σ=1.5"])
    plt.show()

自己實現的不同引數的正態分佈函式曲線

二,背景

訓練深度神經網路的複雜性在於,因為前面的層的引數會發生變化導致每層輸入的分佈在訓練過程中會發生變化。這又導致模型需要需要較低的學習率和非常謹慎的引數初始化策略,從而減慢了訓練速度,並且具有飽和非線性的模型訓練起來也非常困難。

網路層輸入資料分佈發生變化的這種現象稱為內部協變數轉移,BN 就是來解決這個問題。

2.1,如何理解 Internal Covariate Shift

在深度神經網路訓練的過程中,由於網路中引數變化而引起網路中間層資料分佈發生變化的這一過程被稱在論文中稱之為內部協變數偏移(Internal Covariate Shift)。

那麼,為什麼網路中間層資料分佈會發生變化呢?

在深度神經網路中,我們可以將每一層視為對輸入的訊號做了一次變換(暫時不考慮啟用,因為啟用函式不會改變輸入資料的分佈):

\[Z = W \cdot X + B \tag{5} \]

其中 \(W\)\(B\) 是模型學習的引數,這個公式涵蓋了全連線層和卷積層。

隨著 SGD 演算法更新引數,和網路的每一層的輸入資料經過公式5的運算後,其 \(Z\)分佈一直在變化,因此網路的每一層都需要不斷適應新的分佈,這一過程就被叫做 Internal Covariate Shift。

而深度神經網路訓練的複雜性在於每層的輸入受到前面所有層的引數的影響—因此當網路變得更深時,網路引數的微小變化就會被放大。

2.2,Internal Covariate Shift 帶來的問題

  1. 網路層需要不斷適應新的分佈,導致網路學習速度的降低

  2. 網路層輸入資料容易陷入到非線性的飽和狀態並減慢網路收斂,這個影響隨著網路深度的增加而放大。

    隨著網路層的加深,後面網路輸入 \(x\) 越來越大,而如果我們又採用 Sigmoid 型啟用函式,那麼每層的輸入很容易移動到非線性飽和區域,此時梯度會變得很小甚至接近於 \(0\),導致引數的更新速度就會減慢,進而又會放慢網路的收斂速度。

飽和問題和由此產生的梯度消失通常透過使用修正線性單元啟用(ReLU(x)=max(x,0)$),更好的引數初始化方法和小的學習率來解決。然而,如果我們能保證非線性輸入的分佈在網路訓練時保持更穩定,那麼最佳化器將不太可能陷入飽和狀態,進而訓練也將加速。

2.3,減少 Internal Covariate Shift 的一些嘗試

  1. 白化(Whitening): 即輸入線性變換為具有零均值和單位方差,並去相關。

    白化過程由於改變了網路每一層的分佈,因而改變了網路層中本身資料的表達能力。底層網路學習到的引數資訊會被白化操作丟失掉,而且白化計算成本也高。

  2. 標準化(normalization)

    Normalization 操作雖然緩解了 ICS 問題,讓每一層網路的輸入資料分佈都變得穩定,但卻導致了資料表達能力的缺失。

三,批次歸一化(BN)

3.1,BN 的前向計算

論文中給出的 Batch Normalizing Transform 演算法計算過程如下圖所示。其中輸入是一個考慮一個大小為 \(m\) 的小批次資料 \(\cal B\)

Batch Normalizing Transform

論文中的公式不太清晰,下面我給出更為清晰的 Batch Normalizing Transform 演算法計算過程。

\(m\) 表示 batch_size 的大小,\(n\) 表示 features 數量,即樣本特徵值數量。在訓練過程中,針對每一個 batch 資料,BN 過程進行的操作是,將這組資料 normalization,之後對其進行線性變換,具體演算法步驟如下:

\[\begin{align} \mu_B &= \frac{1}{m}\sum_1^m x_i \tag{6} \\ \sigma^2_B &= \frac{1}{m} \sum_1^m (x_i-\mu_B)^2 \tag{7} \\ n_i &= \frac{x_i-\mu_B}{\sqrt{\sigma^2_B + \epsilon}} \tag{8} \\ z_i &= \gamma n_i + \beta = \frac{\gamma}{\sqrt{\sigma^2_B + \epsilon}}x_i + (\beta - \frac{\gamma\mu_{B}}{\sqrt{\sigma^2_B + \epsilon}})\tag{9} \\ \end{align} \]

以上公式乘法都為元素乘,即 element wise 的乘法。其中,引數 \(\gamma,\beta\) 是訓練出來的, \(\epsilon\) 是為零防止 \(\sigma_B^2\)\(0\) ,加的一個很小的數值,通常為1e-5。公式各個符號解釋如下:

符號 資料型別 資料形狀
\(X\) 輸入資料矩陣 [m, n]
\(x_i\) 輸入資料第i個樣本 [1, n]
\(N\) 經過歸一化的資料矩陣 [m, n]
\(n_i\) 經過歸一化的單樣本 [1, n]
\(\mu_B\) 批資料均值 [1, n]
\(\sigma^2_B\) 批資料方差 [1, n]
\(m\) 批樣本數量 [1]
\(\gamma\) 線性變換引數 [1, n]
\(\beta\) 線性變換引數 [1, n]
\(Z\) 線性變換後的矩陣 [1, n]
\(z_i\) 線性變換後的單樣本 [1, n]
\(\delta\) 反向傳入的誤差 [m, n]

其中:

\[z_i = \gamma n_i + \beta = \frac{\gamma}{\sqrt{\sigma^2_B + \epsilon}}x_i + (\beta - \frac{\gamma\mu_{B}}{\sqrt{\sigma^2_B + \epsilon}}) \nonumber \]

可以看出 BN 本質上是做線性變換。

3.2,BN 層如何工作

在論文中,訓練一個帶 BN 層的網路, BN 演算法步驟如下圖所示:

Training a Batch-Normalized Network

在訓練期間,我們一次向網路提供一小批資料。在前向傳播過程中,網路的每一層都處理該小批次資料。 BN 網路層按如下方式執行前向傳播計算:

Batch Norm 層執行的前向計算

圖片來源這裡

注意,圖中計算均值與方差的無偏估計方法是吳恩達在 Coursera 上的 Deep Learning 課程上提出的方法:對 train 階段每個 batch 計算的 mean/variance 採用指數加權平均來得到 test 階段 mean/variance 的估計。

在訓練期間,它只是計算此 EMA,但不對其執行任何操作。在訓練結束時,它只是將該值儲存為層狀態的一部分,以供在推理階段使用。

如下圖可以展示BN 層的前向傳播計算過程資料的 shape ,紅色框出來的單個樣本都指代單個矩陣,即運算都是在單個矩陣運算中計算的。

Batch Norm 向量的形狀

圖片來源 這裡

BN 的反向傳播過程中,會更新 BN 層中的所有 \(\beta\)\(\gamma\) 引數。

3.3,訓練和推理式的 BN 層

批次歸一化(batch normalization)的“批次”兩個字,表示在模型的迭代訓練過程中,BN 首先計算小批次( mini-batch,如 32)的均值和方差。但是,在推理過程中,我們只有一個樣本,而不是一個小批次。在這種情況下,我們該如何獲得均值和方差呢?

第一種方法是,使用的均值和方差資料是在訓練過程中樣本值的平均,即:

\[\begin{align} E[x] &= E[\mu_B] \nonumber \\ Var[x] &= \frac{m}{m-1} E[\sigma^2_B] \nonumber \\ \end{align} \]

這種做法會把所有訓練批次的 \(\mu\)\(\sigma\) 都儲存下來,然後在最後訓練完成時(或做測試時)做下平均。

第二種方法是使用類似動量的方法,訓練時,加權平均每個批次的值,權值 \(\alpha\) 可以為0.9:

\[\begin{align} \mu_{mov_{i}} &= \alpha \cdot \mu_{mov_{i}} + (1-\alpha) \cdot \mu_i \nonumber \\ \sigma_{mov_{i}} &= \alpha \cdot \sigma_{mov_{i}} + (1-\alpha) \cdot \sigma_i \nonumber \\ \end{align} \]

推理或測試時,直接使用模型檔案中儲存的 \(\mu_{mov_{i}}\)\(\sigma_{mov_{i}}\) 的值即可。

3.4,實驗

BNImageNet 分類資料集上實驗結果是 SOTA 的,如下表所示:

實驗結果表4

3.5,BN 層的優點

  1. BN 使得網路中每層輸入資料的分佈相對穩定,加速模型訓練和收斂速度

  2. 批標準化可以提高學習率。在傳統的深度網路中,學習率過高可能會導致梯度爆炸或梯度消失,以及陷入差的區域性最小值。批標準化有助於解決這些問題。透過標準化整個網路的啟用值,它可以防止層引數的微小變化隨著資料在深度網路中的傳播而放大。例如,這使 sigmoid 非線性更容易保持在它們的非飽和狀態,這對訓練深度 sigmoid 網路至關重要,但在傳統上很難實現。

  3. BN 允許網路使用飽和非線性啟用函式(如 sigmoid,tanh 等)進行訓練,其能緩解梯度消失問題

  4. 不需要 dropoutLRN(Local Response Normalization)層來實現正則化。批標準化提供了類似丟棄的正則化收益,因為透過實驗可以觀察到訓練樣本的啟用受到同一小批次樣例隨機選擇的影響。

  5. 減少對引數初始化方法的依賴

參考資料

  1. 維基百科-正態分佈
  2. Batch Norm Explained Visually — How it works, and why neural networks need it
  3. [15.5 批次歸一化的原理])(https://microsoft.github.io/ai-edu/基礎教程/A2-神經網路基本原理/第7步 - 深度神經網路/15.5-批次歸一化的原理.html)
  4. Batch Normalization原理與實戰

相關文章