深度學習中 Batch Normalization
這是一個還在被廣泛研究的問題,先把簡單的總結寫起來,後面肯定是要更新的。資料經過歸一化和標準化後可以加快梯度下降的求解速度,這就是Batch Normalization等技術非常流行的原因,它使得可以使用更大的學習率更穩定地進行梯度傳播,甚至增加網路的泛化能力。
1 什麼是歸一化/標準化
Normalization是一個統計學中的概念,我們可以叫它歸一化或者規範化,它並不是一個完全定義好的數學操作(如加減乘除)。它通過將資料進行偏移和尺度縮放調整,在資料預處理時是非常常見的操作,在網路的中間層如今也很頻繁的被使用。
1. 線性歸一化
最簡單來說,歸一化是指將資料約束到固定的分佈範圍,比如8點陣圖像的0~255畫素值,比如0~1。
在數字影像處理領域有一個很常見的線性對比度拉伸操作:
X=(x-xmin)/(xmax-mxin)
它常常可以實現下面的增強對比度的效果。
不過以上的歸一化方法有個非常致命的缺陷,當X最大值或者最小值為孤立的極值點,會影響效能。
2. 零均值歸一化/Z-score標準化
零均值歸一化也是一個常見的歸一化方法,被稱為標準化方法,即每一變數值與其平均值之差除以該變數的標準差。
經過處理後的資料符合均值為0,標準差為1的分佈, 如果原始的分佈是正態分佈,那麼z-score標準化就將原始的正態分佈轉換為標準正態分佈,機器學習中的很多問題都是基於正態分佈的假設,這是更加常用的歸一化方法 。以上兩種方法都是線性變換,對輸入向量X按比例壓縮再進行平移,操作之後原始有量綱的變數變成無量綱的變數。不過它們不會改變分佈本身的形狀,下面以一個指數分佈為例:
如果要改變分佈本身的形狀,下面也介紹兩種。
3.正態分佈Box-Cox變換
box-cox變換可以將一個非正態分佈轉換為正態分佈,使得分佈具有對稱性,變換公式如下:
在這裡lamda是一個基於資料求取的待定變換引數,Box-Cox的效果如下。
4. 直方圖均衡化
直方圖均衡也可以將某一個分佈歸一化到另一個分佈,它通過影像的灰度值分佈,即影像直方圖來對影像進行對比度進調整,可以增強區域性的對比度。它的變換步驟如下:
(1)計算概率密度和累積概率密度。
(2)建立累積概率到灰度分佈範圍的單調線性對映T。
(3)根據T進行原始灰度值到新灰度值的對映。直方圖均衡化將任意的灰度範圍對映到全域性灰度範圍之間,對於8位的影像就是(0,255),它相對於直接線性拉伸,讓分佈更加均勻,對於增強相近灰度的對比度很有效,如下圖。
綜上,歸一化資料的目標,是為了讓資料的分佈變得更加符合期望,增強資料的表達能力。
在深度學習中,因為網路的層數非常多,如果資料分佈在某一層開始有明顯的偏移,隨著網路的加深這一問題會加劇(這在BN的文章中被稱之為internal covariate shift),進而導致模型優化的難度增加,甚至不能優化。所以,歸一化就是要減緩這個問題。
2 Batch Normalization
1 基本原理
現在一般採用批梯度下降方法對深度學習進行優化,這種方法把資料分為若干組,按組來更新引數,一組中的資料共同決定了本次梯度的方向,下降時減少了隨機性。另一方面因為批的樣本數與整個資料集相比小了很多,計算量也下降了很多。
Batch Normalization(簡稱BN)中的batch就是批量資料,即每一次優化時的樣本數目,通常BN網路層用在卷積層後,用於重新調整資料分佈。假設神經網路某層一個batch的輸入為X=[x1,x2,...,xn],其中xi代表一個樣本,n為batch size。首先,我們需要求得mini-batch裡元素的均值:
接下來,求取mini-batch的方差:
這樣我們就可以對每個元素進行歸一化。
最後進行尺度縮放和偏移操作,這樣可以變換回原始的分佈,實現恆等變換,這樣的目的是為了補償網路的非線性表達能力,因為經過標準化之後,偏移量丟失。具體的表達如下,yi就是網路的最終輸出。
假如gamma等於方差,beta等於均值,就實現了恆等變換。
從某種意義上來說,gamma和beta代表的其實是輸入資料分佈的方差和偏移。對於沒有BN的網路,這兩個值與前一層網路帶來的非線性性質有關,而經過變換後,就跟前面一層無關,變成了當前層的一個學習引數,這更加有利於優化並且不會降低網路的能力。
對於CNN,BN的操作是在各個特徵維度之間單獨進行,也就是說各個通道是分別進行Batch Normalization操作的。
如果輸出的blob大小為(N,C,H,W),那麼在每一層normalization就是基於 NHW個數值進行求平均以及方差的操作 ,記住這裡我們後面會進行比較。
2.BN帶來的好處。
(1) 減輕了對引數初始化的依賴,這是利於調參的朋友們的。
(2) 訓練更快,可以使用更高的學習率。
(3) BN一定程度上增加了泛化能力,dropout等技術可以去掉。
3.BN的缺陷
從上面可以看出,batch normalization依賴於batch的大小,當batch值很小時,計算的均值和方差不穩定。研究表明對於ResNet類模型在ImageNet資料集上,batch從16降低到8時開始有非常明顯的效能下降,在訓練過程中計算的均值和方差不準確,而在測試的時候使用的就是訓練過程中保持下來的均值和方差。
這一個特性,導致batch normalization不適合以下的幾種場景。
(1)batch非常小,比如訓練資源有限無法應用較大的batch,也比如線上學習等使用單例進行模型引數更新的場景。(2)rnn,因為它是一個動態的網路結構,同一個batch中訓練例項有長有短,導致每一個時間步長必須維持各自的統計量,這使得BN並不能正確的使用。在rnn中,對bn進行改進也非常的困難。不過,困難並不意味著沒人做,事實上現在仍然可以使用的,不過這超出了我們們初識境的學習範圍。
4.BN的改進
針對BN依賴於batch的這個問題,BN的作者親自現身提供了改進,即在原來的基礎上增加了一個仿射變換。
其中引數r,d就是仿射變換引數,它們本身是通過如下的方式進行計算的
其中引數都是通過滑動平均的方法進行更新
所以r和d就是一個跟樣本有關的引數,通過這樣的變換來進行學習,這兩個引數在訓練的時候並不參與訓練。
在實際使用的時候,先使用BN進行訓練得到一個相對穩定的移動平均,網路迭代的後期再使用剛才的方法,稱為Batch Renormalization,當然r和d的大小必須進行限制。
3 Batch Normalization的變種
Normalization思想非常簡單,為深層網路的訓練做出了很大貢獻。因為有依賴於樣本數目的缺陷,所以也被研究人員盯上進行改進。說的比較多的就是Layer Normalization與Instance Normalization,Group Normalization了。
前面說了Batch Normalization各個通道之間是獨立進行計算,如果拋棄對batch的依賴,也就是每一個樣本都單獨進行normalization,同時各個通道都要用到,就得到了Layer Normalization。
跟Batch Normalization僅針對單個神經元不同,Layer Normalization考慮了神經網路中一層的神經元。如果輸出的blob大小為(N,C,H,W),那麼在每一層Layer Normalization就是基於CHW個數值進行求平均以及方差的操作。
Layer Normalization把每一層的特徵通道一起用於歸一化,如果每一個特徵層單獨進行歸一化呢?也就是限制在某一個特徵通道內,那就是instance normalization了。
如果輸出的blob大小為(N,C,H,W),那麼在每一層Instance Normalization就是基於H*W個數值進行求平均以及方差的操作。對於風格化類的影像應用,Instance Normalization通常能取得更好的結果,它的使用本來就是風格遷移應用中提出。
Group Normalization是Layer Normalization和Instance Normalization 的中間體, Group Normalization將channel方向分group,然後對每個Group內做歸一化,算其均值與方差。
如果輸出的blob大小為(N,C,H,W),將通道C分為G個組,那麼Group Normalization就是基於GHW個數值進行求平均以及方差的操作。我只想說,你們真會玩,要榨乾所有可能性。
在Batch Normalization之外,有人提出了通用版本Generalized Batch Normalization,有人提出了硬體更加友好的L1-Norm Batch Normalization等,不再一一講述。
另一方面,以上的Batch Normalization,Layer Normalization,Instance Normalization都是將規範化應用於輸入資料x,Weight normalization則是對權重進行規範化,感興趣的可以自行了解,使用比較少,也不在我們的討論範圍。
這麼多的Normalization怎麼使用呢?有一些基本的建議吧,不一定是正確答案。
(1) 正常的處理圖片的CNN模型都應該使用Batch Normalization。只要保證batch size較大(不低於32),並且打亂了輸入樣本的順序。如果batch太小,則優先用Group Normalization替代。
(2)對於RNN等時序模型,有時候同一個batch內部的訓練例項長度不一(不同長度的句子),則不同的時態下需要儲存不同的統計量,無法正確使用BN層,只能使用Layer Normalization。
(3) 對於影像生成以及風格遷移類應用,使用Instance Normalization更加合適。
4 Batch Normalization的思考
最後是關於Batch Normalization的思考 ,應該說,normalization機制至今仍然是一個非常open的問題,相關的理論研究一直都有,大家最關心的是Batch Normalization怎麼就有效了 。之所以只說Batch Normalization,是因為上面的這些方法的差異主要在於計算normalization的元素集合不同。Batch Normalization是NHW,Layer Normalization是CHW,Instance Normalization是HW,Group Normalization是GH*W。關於Normalization的有效性,有以下幾個主要觀點:
(1) 主流觀點,Batch Normalization調整了資料的分佈,不考慮啟用函式,它讓每一層的輸出歸一化到了均值為0方差為1的分佈,這保證了梯度的有效性,目前大部分資料都這樣解釋,比如BN的原始論文認為的緩解了Internal Covariate Shift(ICS)問題。
(2) 可以使用更大的學習率,文[2]指出BN有效是因為用上BN層之後可以使用更大的學習率,從而跳出不好的區域性極值,增強泛化能力,在它們的研究中做了大量的實驗來驗證。
(3) 損失平面平滑。文[3]的研究提出,BN有效的根本原因不在於調整了分佈,因為即使是在BN層後模擬ICS,也仍然可以取得好的結果。它們指出,BN有效的根本原因是平滑了損失平面。之前我們說過,Z-score標準化對於包括孤立點的分佈可以進行更平滑的調整。
[1] Ioffe S, Szegedy C. Batch normalization: Accelerating deep network training by reducing internal covariate shift[J]. arXiv preprint arXiv:1502.03167, 2015.
[2] Bjorck N, Gomes C P, Selman B, et al. Understanding batch normalization[C]//Advances in Neural Information Processing Systems. 2018: 7705-7716.
[3] Santurkar S, Tsipras D, Ilyas A, et al. How does batch normalization help optimization?[C]//Advances in Neural Information Processing Systems. 2018: 2488-2498.
BN層技術的出現確實讓網路學習起來更加簡單了,降低了調參的工作量,不過它本身的作用機制還在被廣泛研究中。幾乎就像是深度學習中沒有open問題的一個縮影,BN到底為何,還無定論,如果你有興趣和時間,不妨也去踩一坑。
相關文章
- 【深度學習筆記】Batch Normalization (BN)深度學習筆記BATORM
- batch normalization學習理解筆記BATORM筆記
- 深度學習中的Normalization模型深度學習ORM模型
- 解毒batch normalizationBATORM
- 深度學習中的Normalization模型(附例項&公式)深度學習ORM模型公式
- TensorFlow實現Batch NormalizationBATORM
- Batch Normalization: 如何更快地訓練深度神經網路BATORM神經網路
- 深度學習當中的三個概念:Epoch, Batch, Iteration深度學習BAT
- BN(Batch Normalization)層的詳細介紹BATORM
- 淺談深度學習訓練中資料規範化(Normalization)的重要性深度學習ORM
- [PyTorch 學習筆記] 6.2 NormalizationPyTorch筆記ORM
- 深度學習——學習目錄——學習中……深度學習
- 關於深度學習上的一些術語: Epoch, Batch Size, Iteration深度學習BAT
- 深度學習中的epochs,batch_size,iterations詳解---對這三個概念說的比較清楚深度學習BAT
- 深度學習中的Dropout深度學習
- 神經網路 深度學習 專業術語解釋(Step, Batch Size, Iteration,Epoch)神經網路深度學習BAT
- 學習筆記:深度學習中的正則化筆記深度學習
- 深度學習+深度強化學習+遷移學習【研修】深度學習強化學習遷移學習
- 深度學習及深度強化學習研修深度學習強化學習
- 【深度學習】深度解讀:深度學習在IoT大資料和流分析中的應用深度學習大資料
- 深度學習學習框架深度學習框架
- ####深度學習深度學習
- 深度學習深度學習
- 深度 學習
- 深度學習及深度強化學習應用深度學習強化學習
- 深度學習在OC中的應用深度學習
- 深度學習中的優化方法(二)深度學習優化
- 深度學習中的優化方法(一)深度學習優化
- 淺談深度學習中的機率深度學習
- 讀懂深度學習,走進“深度學習+”階段深度學習
- Layer NormalizationORM
- 深度學習模型深度學習模型
- Python深度學習Python深度學習
- 深度學習引言深度學習
- MySQL深度學習MySql深度學習
- 深度學習-LSTM深度學習
- 深度學習《CycleGAN》深度學習
- 深度學習《StarGAN》深度學習