LSTM神經網路

超人汪小建發表於2019-03-02

LSTM是什麼

LSTM即Long Short Memory Network,長短時記憶網路。它其實是屬於RNN的一種變種,可以說它是為了克服RNN無法很好處理遠距離依賴而提出的。

我們說RNN不能處理距離較遠的序列是因為訓練時很有可能會出現梯度消失,即通過下面的公式訓練時很可能會發生指數縮小,讓RNN失去了對較遠時刻的感知能力。

解決思路

RNN梯度消失不應該是由我們學習怎麼去避免,而應該通過改良讓迴圈神經網路自己具備避免梯度消失的特性,從而讓迴圈神經網路自身具備處理長期序列依賴的能力。

RNN的狀態計算公式為

,根據鏈式求導法則會導致梯度變為連乘的形式,而sigmoid小於1會讓連乘小得很快。為了解決這個問題,科學家採用了累加的形式,

,其導數也為累加,從而避免梯度消失。LSTM即是使用了累加形式,但它的實現較複雜,下面進行介紹。

LSTM模型

回顧一下RNN的模型,如下圖,展開後多個時刻隱層互相連線,而所有迴圈神經網路都有一個重複的網路模組,RNN的重複網路模組很簡單,如下下圖,比如只有一個tanh層。

這裡寫圖片描述
這裡寫圖片描述

這裡寫圖片描述
這裡寫圖片描述

而LSTM的重複網路模組的結構則複雜很多,它實現了三個門計算,即遺忘門、輸入門和輸出門。每個門負責是事情不一樣,遺忘門負責決定保留多少上一時刻的單元狀態到當前時刻的單元狀態;輸入門負責決定保留多少當前時刻的輸入到當前時刻的單元狀態;輸出門負責決定當前時刻的單元狀態有多少輸出。

這裡寫圖片描述
這裡寫圖片描述

每個LSTM包含了三個輸入,即上時刻的單元狀態、上時刻LSTM的輸出和當前時刻輸入。

LSTM的機制

這裡寫圖片描述
這裡寫圖片描述

根據上圖我們們一步一步來看LSTM神經網路是怎麼運作的。

首先看遺忘門,用來計算哪些資訊需要忘記,通過sigmoid處理後為0到1的值,1表示全部保留,0表示全部忘記,於是有

其中中括號表示兩個向量相連合並,W_f是遺忘門的權重矩陣,sigma為sigmoid函式,b_f為遺忘門的偏置項。設輸入層維度為d_x,隱藏層維度為d_h,上面的狀態維度為d_c,則W_f的維度為d_c imes(d_h+d_x)。

這裡寫圖片描述
這裡寫圖片描述

其次看輸入門,輸入門用來計算哪些資訊儲存到狀態單元中,分兩部分,第一部分為

該部分可以看成當前輸入有多少是需要儲存到單元狀態的。第二部分為

該部分可以看成當前輸入產生的新資訊來新增到單元狀態中。結合這兩部分來建立一個新記憶。

這裡寫圖片描述
這裡寫圖片描述

而當前時刻的單元狀態由遺忘門輸入和上一時刻狀態的積加上輸入門兩部分的積,即

這裡寫圖片描述
這裡寫圖片描述

最後看看輸出門,通過sigmoid函式計算需要輸出哪些資訊,再乘以當前單元狀態通過tanh函式的值,得到輸出。

這裡寫圖片描述
這裡寫圖片描述

LSTM的訓練

化繁為簡,這裡只討論包含一個LSTM層的三層神經網路(如果有多個層則誤差項除了沿時間反向傳播外,還會向上一層傳播),LSTM向前傳播時與三個門相關的公式如下,

需要學習的引數挺多的,同時也可以看到LSTM的輸出h_t有四個輸入分量加權影響,即三個門相關的 f_t i_t c_t o_t,而且其中權重W都是拼接的,所以在學習時需要分割出來,即

相關閱讀:
迴圈神經網路
卷積神經網路
機器學習之神經網路
機器學習之感知器
神經網路的交叉熵損失函式

========廣告時間========

鄙人的新書《Tomcat核心設計剖析》已經在京東銷售了,有需要的朋友可以到 item.jd.com/12185360.ht… 進行預定。感謝各位朋友。

為什麼寫《Tomcat核心設計剖析》

=========================

歡迎關注:

這裡寫圖片描述
這裡寫圖片描述

相關文章