一文讀懂:梯度消失(爆炸)及其解決方法

忽逢桃林發表於2020-06-21

梯度消失問題和梯度爆炸問題,總的來說可以稱為梯度不穩定問題

【要背住的知識】:用ReLU代替Sigmoid,用BN層,用殘差結構解決梯度消失問題。梯度爆炸問題的話,可以用正則化來限制。sigmoid的導數是【0,0.25】.

出現原因

兩者出現原因都是因為鏈式法則。當模型的層數過多的時候,計算梯度的時候就會出現非常多的乘積項。用下面這個例子來理解:

這是每層只有1個神經元的例子,每個神經元的啟用函式都是sigmoid,然後我們想要更新b1這個引數。
按照大家都公認的符號來表示:

  • \(w_1\*x_1 + b_1 = z_1\)這就是z的含義;
  • \(\sigma(z_1)=a_1\),這是a的含義。

可以得到這個偏導數:
\(\frac{\partial C}{\partial b_1} = \frac{\partial z_1}{\partial b_1}\frac{\partial a_1}{\partial z_1} \frac{\partial z_2}{\partial a_2}\frac{\partial a_2}{\partial z_2} \frac{\partial z_2}{\partial a_3}\frac{\partial a_3}{\partial z_3} \frac{\partial z_3}{\partial a_4}\frac{\partial a_4}{\partial z_4} \frac{\partial C}{\partial a_4}\)

然後化簡:
\(\frac{\partial C}{\partial b_1}=\sigma'(z_1)w_2\sigma'(z_2)w_3\sigma'(z_3)w_4\sigma'(z_4)\frac{\partial C}{\partial a_4}\)

關鍵在於這個\(\sigma'(z_1)\),sigmoid函式的導數,是在0~0.25這個區間的,這意味著,當網路層數越深,那麼對於前面幾層的梯度,就會非常的小。下圖是sigmoid函式的導數的函式圖:

因此經常會有這樣的現象:

圖中,分別表示4層隱含層的梯度變化幅度。可以看到,最淺的那個隱含層,梯度更新的速度,是非常小的。【圖中縱軸是指數變化的】。

那麼梯度爆炸也很好理解,就是\(w_j\sigma'(z_j)>1\),這樣就爆炸了。
【注意:如果啟用函式是sigmoid,那麼其導數最大也就0.25,而\(w_j\)一般不會大於4的,所以sigmoid函式而言,一般都是梯度消失問題】

【總結】:

  1. 梯度消失和梯度爆炸是指前面幾層的梯度,因為鏈式法則不斷乘小於(大於)1的數,導致梯度非常小(大)的現象;
  2. sigmoid導數最大0.25,一般都是梯度消失問題。

解決方案

更換啟用函式

最常見的方案就是更改啟用函式,現在神經網路中,除了最後二分類問題的最後一層會用sigmoid之外,每一層的啟用函式一般都是用ReLU。

【ReLU】:如果啟用函式的導數是1,那麼就沒有梯度爆炸問題了。

【好處】:可以發現,relu函式的導數在正數部分,是等於1的,因此就可以避免梯度消失的問題。
【不好】:但是負數部分的導數等於0,這樣意味著,只要在鏈式法則中某一個\(z_j\)小於0,那麼這個神經元的梯度就是0,不會更新。

【leakyReLU】:在ReLU的負數部分,增加了一定的斜率:

解決了ReLU中會有死神經元的問題。

【elu】:跟LeakyReLU一樣是為了解決死神經元問題,但是增加的斜率不是固定的:

但是相比leakrelu,計算量更大。

batchnorm層

這個是非常給力的成功,在影像處理中必用的層了。BN層提出來的本質就是為了解決反向傳播中的梯度問題

在神經網路中,有這樣的一個問題:Internal Covariate Shift
假設第一層的輸入資料經過第一層的處理之後,得到第二層的輸入資料。這時候,第二層的輸入資料相對第一層的資料分佈,就會發生改變,所以這一個batch,第二層的引數更新是為了擬合第二層的輸入資料的那個分佈。然而到了下一個batch,因為第一層的引數也改變了,所以第二層的輸入資料的分佈相比上一個batch,又不太一樣了。然後第二層的引數更新方向也會發生改變。層數越多,這樣的問題就越明顯。

但是為了保證每一層的分佈不變的話,那麼如果把每一層輸出的資料都歸一化0均值,1方差不就好了?但是這樣就會完全學習不到輸入資料的特徵了。不管什麼資料都是服從標準正太分佈,想想也會覺得有點奇怪。所以BN就是增加了兩個自適應引數,可以通過訓練學習的那種引數。這樣吧每一層的資料都歸一化到\(\beta\)均值,\(\gamma\)標準差的正態分佈上。

【將輸入分佈變成正態分佈,是一種去除資料絕對差異,擴大相對差異的一種行為,所以BN層用在分類上效果的好的。對於Image-to-Image這種任務,資料的絕對差異也是非常重要的,所以BN層可能起不到相應的效果。】

殘差結構


殘差結構,簡單的理解,就是讓深層網路通過走捷徑,讓網路不那麼深層。這樣梯度消失的問題就緩解了。

正則化

之前提到的梯度爆炸問題,一般都是因為\(w_j\)過大造成的,那麼用L2正則化就可以解決問題。


喜歡的話請關注我們的微信公眾號~【你好世界煉丹師】。

  • 公眾號主要講統計學,資料科學,機器學習,深度學習,以及一些參加Kaggle競賽的經驗。
  • 公眾號內容建議作為課後的一些相關知識的補充,飯後甜點。
  • 此外,為了不過多打擾,公眾號每週推送一次,每次4~6篇精選文章。

微信搜尋公眾號:你好世界煉丹師。期待您的關注。

相關文章