迴圈神經網路

Galois發表於2020-03-14

RNN是什麼

迴圈神經網路即recurrent neural network,它的提出主要是為了處理序列資料,序列資料是什麼?就是前面的輸入和後面的輸入是有關聯的,比如一句話,前後的詞都是有關係的,“我肚子餓了,準備去xx”,根據前面的輸入判斷“xx”很大可能就是“吃飯”。這個就是序列資料。

迴圈神經網路有很多變種,比如LSTM、GRU等,這裡搞清楚基礎的迴圈神經網路的思想,對於理解其他變種就比較容易了。

與傳統神經網路區別

下圖是我們經典的全連線網路,從輸入層到兩個隱含層再到輸出層,四層之間都是全連線的,而且層內之間的節點不相連。這種網路模型對於序列資料的預測就基本無能為力,比如某句話的下一個單詞是什麼就很難處理。

迴圈神經網路則擅長處理序列資料,它會對前面的資訊進行記憶並且參與當前輸出的計算,理論上迴圈神經網路能處理任意長度的序列資料。

RNN模型

RNN模型最抽象的畫法就是下面這種了,但它不太好理解,因為它將時間維度擠壓了。其中x是輸入,U是輸出層到隱含層的權重,s是隱含層值,W則是上個時刻隱含層作為這個時刻輸入的權重,V是隱含層到輸出層的權重,o是輸出。

為方便理解,將上圖展開,現在可以清楚看到輸入x、隱層值s和輸出o都有了下標t,這個t表示時刻,t-1是上一時刻,t+1則是下一時刻。不同時刻輸入對應不同的輸出,而且上一時刻的隱含層會影響當前時刻的輸出。

那麼反應到神經元是怎樣的呢?如下圖,這下就更清晰了,輸入的3個神經元連線4個隱含層神經元,然後保留隱含層狀態用於下一刻參與計算。

RNN的正向傳播

這裡寫圖片描述
這裡寫圖片描述

RNN的訓練

假設損失函式為

E

在t時刻,根據誤差逆傳播,有

首先,我們來看看對V的求導,每個時刻t的誤差至於當前時刻的誤差相關,則

其次,對W求導,對於一個訓練樣本,所有時刻的誤差加起來才是這個樣本的誤差,某時刻t對W求偏導為,

其中


一直依賴上個時刻,某個樣本的總誤差是需要所有時刻加起來,不斷對某個時刻進行求偏導,誤差一直反向傳播到t為0時刻,則

其中


根據鏈式法則是會一直乘到k
時刻,k可以是0、1、2...,那麼上式可以表示成,

最後,對U求導,

通過上面實現梯度下降訓練。

梯度消失或梯度爆炸

對於tanh和sigmoid啟用函式的RNN,我們說它不能很好的處理較長的序列,這個是為什麼呢?簡單說就是因為RNN很容易會存在梯度消失或梯度爆炸問題,發生這種情況時RNN就捕捉不了很早之前的序列的影響。

為什麼會這樣?接著往下看,tanh和sigmoid的梯度大致如下圖所示,兩端的梯度值都基本接近0了,而從上面的求導公式可以看到

其中有個連乘操作,而向量函式對向量求導結果為一個Jacobian矩陣,元素為每個點的導數,離當前時刻越遠則會乘越多啟用函式的導數,指數型,本來就接近0的梯度再經過指數就更加小,基本忽略不計了,於是便接收不到遠距離的影響,這就是RNN處理不了較長序列的原因。

而當矩陣中的值太大時,經過指數放大,則會產生梯度爆炸。

梯度爆炸會導致程式NaN,可以設定一個梯度閾值來處理。

梯度消失則可以用ReLU來替代tanh和sigmoid啟用函式,或者用LSTM或GRU結構。

RNN簡單應用例子

比如可以做字元級別的預測,如下圖,假如這裡只有四種字元,樣本為"hello"單詞,則輸入h預測下個字元為e,e接著則輸出l,l則輸出l,最後輸入l則輸出o。

========廣告時間========

鄙人的新書《Tomcat核心設計剖析》已經在京東銷售了,有需要的朋友可以到 item.jd.com/12185360.ht… 進行預定。感謝各位朋友。

為什麼寫《Tomcat核心設計剖析》

=========================

歡迎關注:

相關文章