RNN梯度消失與梯度爆炸的原因

貪心科技發表於2019-01-17

• 關於RNN結構

• 關於RNN前向傳播

• 關於RNN反向傳播

• 解決方法

1. 關於RNN結構

迴圈神經網路RNN(Recurrent Neural Network)是用於處理序列資料的一種神經網路,已經在自然語言處理中被廣泛應用。下圖為經典RNN結構:

RNN梯度消失與梯度爆炸的原因

RNN結構

2. 關於RNN前向傳播

RNN前向傳導公式:

RNN梯度消失與梯度爆炸的原因其中:  St :  t 時刻的隱含層狀態值

Ot :  t 時刻的輸出值

① 是隱含層計算公式,U是輸入x的權重矩陣,W是時刻t-1的狀態值

St-1作為輸入的權重矩陣,Φ是啟用函式

② 是輸出層計算公式,V是輸出層的權重矩陣,f是啟用函式

損失函式(loss function)採用交叉熵RNN梯度消失與梯度爆炸的原因( Ot 是t時刻預測輸出, RNN梯度消失與梯度爆炸的原因是 t 時刻正確的輸出) 

那麼對於一次訓練任務中,損失函式RNN梯度消失與梯度爆炸的原因, T 是序列總長度。

假設初始狀態St為0,t=3 有三段時間序列時,由 ① 帶入②可得到 

t1、t2、t3 各個狀態和輸出

RNN梯度消失與梯度爆炸的原因

RNN梯度消失與梯度爆炸的原因

RNN梯度消失與梯度爆炸的原因

3. 關於RNN反向傳播

BPTT(back-propagation through time)演算法是針對循層的訓練演算法,它的基本原理和BP演算法一樣。其演算法本質還是梯度下降法,那麼該演算法的關鍵就是計算各個引數的梯度,對於RNN來說引數有 U、W、V。

RNN梯度消失與梯度爆炸的原因

反向傳播

RNN梯度消失與梯度爆炸的原因可以簡寫成:

RNN梯度消失與梯度爆炸的原因

RNN梯度消失與梯度爆炸的原因觀察③④⑤式,可知,對於 V 求偏導不存在依賴問題;但是對於 W、U 求偏導的時候,由於時間序列長度,存在長期依賴的情況。主要原因可由 t=1、2、3 的情況觀察得 , St會隨著時間序列向前傳播,同時St是 U、W 的函式。

前面得出的求偏導公式⑥,取其中累乘的部分出來,其中啟用函式 Φ 通常是:tanh 則

RNN梯度消失與梯度爆炸的原因

RNN梯度消失與梯度爆炸的原因

由上圖可知當啟用函式是tanh函式時,tanh函式的導數最大值為1,又不可能一直都取1這種情況,而且這種情況很少出現,那麼也就是說,大部分都是小於1的數在做累乘,若當t很大的時候,RNN梯度消失與梯度爆炸的原因趨向0,舉個例子:0.850=0.00001427247也已經接近0了,這是RNN中梯度消失的原因。

再看⑦部分:

RNN梯度消失與梯度爆炸的原因

tanh’,還需要網路引數 W ,如果引數 W 中的值太大,隨著序列長度同樣存在長期依賴的情況,那麼產生問題就是梯度爆炸,而不是梯度消失了,在平時運用中,RNN比較深,使得梯度爆炸或者梯度消失問題會比較明顯。

4. 解決方法

面對梯度消失問題,可以採用ReLu作為啟用函式,下圖為ReLu函式

RNN梯度消失與梯度爆炸的原因

ReLU函式在定義域大於0部分的導數恆等於1,這樣可以解決梯度消失的問題,(雖然恆等於1很容易發生梯度爆炸的情況,但可透過設定適當的閾值可解決)。

另外計算方便,計算速度快,可以加速網路訓練。但是,定義域負數部分恆等於零,這樣會造成神經元無法啟用(可透過合理設定學習率,降低發生的機率)。

ReLU有優點也有缺點,其中的缺點可以透過其他操作取避免或者減低發生的機率,是目前使用最多的啟用函式

還可以透過更改內部結構來解決梯度消失和梯度爆炸問題,那就是LSTM了~!

知乎原文連結 :

 https://zhuanlan.zhihu.com/p/53405950

相關文章