文章來自:公眾號【機器學習煉丹術】。求關注~
其實關於BN層,我在之前的文章“梯度爆炸”那一篇中已經涉及到了,但是鑑於面試經歷中多次問道這個,這裡再做一個更加全面的講解。
Internal Covariate Shift(ICS)
Batch Normalization的原論文作者給了Internal Covariate Shift一個較規範的定義:在深層網路訓練的過程中,由於網路中引數變化而引起內部結點資料分佈發生變化的這一過程被稱作Internal Covariate Shift。
這裡做一個簡單的數學定義,對於全連結網路而言,第i層的數學表達可以體現為:
\(Z^i=W^i\times input^i+b^i\)
\(input^{i+1}=g^i(Z^i)\)
- 第一個公式就是一個簡單的線性變換;
- 第二個公式是表示一個啟用函式的過程。
【怎麼理解ICS問題】
我們知道,隨著梯度下降的進行,每一層的引數\(W^i,b^i\)都會不斷地更新,這意味著\(Z^i\)的分佈也不斷地改變,從而\(input^{i+1}\)的分佈發生了改變。這意味著,除了第一層的輸入資料不改變,之後所有層的輸入資料的分佈都會隨著模型引數的更新發生改變,而每一層就要不停的去適應這種資料分佈的變化,這個過程就是Internal Covariate Shift。
BN解決的問題
【ICS帶來的收斂速度慢】
因為每一層的引數不斷髮生變化,從而每一層的計算結果的分佈發生變化,後層網路不斷地適應這種分佈變化,這個時候會讓整個網路的學習速度過慢。
【梯度飽和問題】
因為神經網路中經常會採用sigmoid,tanh這樣的飽和啟用函式(saturated actication function),因此模型訓練有陷入梯度飽和區的風險。解決這樣的梯度飽和問題有兩個思路:第一種就是更為非飽和性啟用函式,例如線性整流函式ReLU可以在一定程度上解決訓練進入梯度飽和區的問題。另一種思路是,我們可以讓啟用函式的輸入分佈保持在一個穩定狀態來儘可能避免它們陷入梯度飽和區,這也就是Normalization的思路。
Batch Normalization
batchNormalization就像是名字一樣,對一個batch的資料進行normalization。
現在假設一個batch有3個資料,每個資料有兩個特徵:(1,2),(2,3),(0,1)
如果做一個簡單的normalization,那麼就是計算均值和方差,把資料減去均值除以標準差,變成0均值1方差的標準形式。
對於第一個特徵來說:
\(\mu=\frac{1}{3}(1+2+0)=1\)
\(\sigma^2=\frac{1}{3}((1-1)^2+(2-1)^2+(0-1)^2)=0.67\)
【通用公式】
\(\mu=\frac{1}{m}\sum_{i=1}^m{Z}\)
\(\sigma^2=\frac{1}{m}\sum_{i=1}^m(Z-\mu)\)
\(\hat{Z}=\frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}\)
- 其中m表示一個batch的數量。
- \(\epsilon\)是一個極小數,防止分母為0。
目前為止,我們做到了讓每個特徵的分佈均值為0,方差為1。這樣分佈都一樣,一定不會有ICS問題
如同上面提到的,Normalization操作我們雖然緩解了ICS問題,讓每一層網路的輸入資料分佈都變得穩定,但卻導致了資料表達能力的缺失。每一層的分佈都相同,所有任務的資料分佈都相同,模型學啥呢
【0均值1方差資料的弊端】
- 資料表達能力的缺失;
- 通過讓每一層的輸入分佈均值為0,方差為1,會使得輸入在經過sigmoid或tanh啟用函式時,容易陷入非線性啟用函式的線性區域。(線性區域和飽和區域都不理想,最好是非線性區域)
為了解決這個問題,BN層引入了兩個可學習的引數\(\gamma\)和\(\beta\),這樣,經過BN層normalization的資料其實是服從\(\beta\)均值,\(\gamma^2\)方差的資料。
所以對於某一層的網路來說,我們現在變成這樣的流程:
- \(Z=W\times input^i+b\)
- \(\hat{Z}=\gamma \times \frac{Z-\mu}{\sqrt{\sigma^2+\epsilon}}+\beta\)
- \(input^{i+1}=g(\hat{Z})\)
(上面公式中,省略了\(i\),總的來說是表示第i層的網路層產生第i+1層輸入資料的過程)
測試階段的BN
我們知道BN在每一層計算的\(\mu\)與\(\sigma^2\) 都是基於當前batch中的訓練資料,但是這就帶來了一個問題:我們在預測階段,有可能只需要預測一個樣本或很少的樣本,沒有像訓練樣本中那麼多的資料,這樣的\(\sigma^2\)和\(\mu\)要怎麼計算呢?
利用訓練集訓練好模型之後,其實每一層的BN層都保留下了每一個batch算出來的\(\mu\)和\(\sigma^2\).然後呢利用整體的訓練集來估計測試集的\(\mu_{test}\)和\(\sigma_{test}^2\)
\(\mu_{test}=E(\mu_{train})\)
\(\sigma_{test}^2=\frac{m}{m-1}E(\sigma_{train}^2)\)
然後再對測試機進行BN層:
當然,計算訓練集的\(\mu\)和\(\simga\)的方法除了上面的求均值之外。吳恩達老師在其課程中也提出了,可以使用指數加權平均的方法。不過都是同樣的道理,根據整個訓練集來估計測試機的均值方差。
BN層的好處有哪些
-
BN使得網路中每層輸入資料的分佈相對穩定,加速模型學習速度。
BN通過規範化與線性變換使得每一層網路的輸入資料的均值與方差都在一定範圍內,使得後一層網路不必不斷去適應底層網路中輸入的變化,從而實現了網路中層與層之間的解耦,允許每一層進行獨立學習,有利於提高整個神經網路的學習速度。 -
BN允許網路使用飽和性啟用函式(例如sigmoid,tanh等),緩解梯度消失問題
通過normalize操作可以讓啟用函式的輸入資料落在梯度非飽和區,緩解梯度消失的問題;另外通過自適應學習\(\gamma\)與 \(\beta\) 又讓資料保留更多的原始資訊。 -
BN具有一定的正則化效果
在Batch Normalization中,由於我們使用mini-batch的均值與方差作為對整體訓練樣本均值與方差的估計,儘管每一個batch中的資料都是從總體樣本中抽樣得到,但不同mini-batch的均值與方差會有所不同,這就為網路的學習過程中增加了隨機噪音
BN與其他normalizaiton的比較
【weight normalization】
Weight Normalization是對網路權值進行normalization,也就是L2 norm。
相對於BN有下面的優勢:
- WN通過重寫神經網路的權重的方式來加速網路引數的收斂,不依賴於mini-batch。BN因為以來minibatch所以BN不能用於RNN網路,而WN可以。而且BN要儲存每一個batch的均值方差,所以WN節省記憶體;
- BN的優點中有正則化效果,但是新增噪音不適合對噪聲敏感的強化學習、GAN等網路。WN可以引入更小的噪音。
但是WN要特別注意引數初始化的選擇。
【Layer normalization】
更常見的比較是BN與LN的比較。
BN層有兩個缺點:
- 無法進行線上學習,因為線上學習的mini-batch為1;LN可以
- 之前提到的BN不能用在RNN中;LN可以
- 消耗一定的記憶體來記錄均值和方差;LN不用
但是,在CNN中LN並沒有取得比BN更好的效果。
參考連結:
- https://zhuanlan.zhihu.com/p/34879333
- https://www.zhihu.com/question/59728870
- https://zhuanlan.zhihu.com/p/113233908
- https://www.zhihu.com/question/55890057/answer/267872896