LSTM - 長短期記憶網路

renyuzhuo發表於2021-02-08

迴圈神經網路(RNN)

人們不是每一秒都從頭開始思考,就像你閱讀本文時,不會從頭去重新學習一個文字,人類的思維是有持續性的。傳統的卷積神經網路沒有記憶,不能解決這一個問題,迴圈神經網路(Recurrent Neural Networks)可以解決這一個問題,在迴圈神經網路中,通過迴圈可以解決沒有記憶的問題,如下圖:

RNN-rolled

看到這裡,你可能還是不理解為什迴圈神經網路就可以有記憶。我們把這個圖展開:

RNN-unrolled

可以看出,我們輸入 \(X_0\) 後,首先警告訓練,得到輸出 \(h_0\),同時會把這個輸出傳遞給下一次訓練 \(X_1\),普通的神經網路是不會這樣做的,這時對 \(X_1\) 進行訓練時,輸入就包括了 \(X_1\) 本身和 訓練 \(X_0\) 的輸出,前面的訓練對後面有印象,同樣的道理,之後的每一次訓練都收到了前面的輸出的影響(對 \(X_1\) 訓練的輸出傳遞給訓練 \(X_2\) 的過程,\(X_0\)\(X_2\) 的影響是間接的)。

遇到的問題

迴圈神經網路很好用,但是還有一些問題,主要體現在沒辦法進行長期記憶。我們可以想象(也有論文證明),前期的某一次輸入,在較長的鏈路上傳遞時,對後面的影響越來越小,相當於網路有一定的記憶能力,但是記憶力只有 7 秒,很快就忘記了,如下圖 \(X_0\)\(X_1\)\(h_{t+1}\) 的影響就比較小了(理論上通過調整引數避免這個問題,但是尋找這個引數太難了,實踐中不好應用,因此可以近似認為不可行),LSTM 的提出就是為了解決這個問題的。

RNN-longtermdependencies

LSTM

LSTM(Long Short Term Memory)本質還是一種 RNN,只不過其中的那個迴圈,上圖中的那個 A被重新設計了,目的就是為了解決記憶時間不夠長的問題,其他神經網路努力調整引數為的是使記憶力更好一點,結果 LSTM 天生過目不忘,簡直降維打擊!

普通的 RNN 中的 A 如下圖,前一次的輸入和本次的輸入,進行一次運算,圖中用的是 tanh:

LSTM3-SimpleRNN

相比較起來,LSTM 中的 A 就顯得複雜了好多,不是上圖單一的神經網路層,而是有四層,如下圖,並且似乎這麼看還有點看不懂,這就是本文需要重點分析的內容,仔細認真讀下去,定會有收穫:

LSTM3-chain

定義一些圖形的含義,黃色方框是簡單的神經網路層;粉色的代表逐點操作,如加法乘法;還有合併和分開(拷貝)操作:

LSTM2-notation

核心思想

首先看下圖高亮部分,前一次的輸出,可以幾乎沒有阻礙的一直沿著這條高速公路流動,多麼簡單樸素的思想,既然希望前面的訓練不被遺忘,那就一直傳遞下去:

LSTM3-C-line

當然,為了讓這種傳遞更加有意義,需要加入一些門的控制,這種門具有選擇性,可以完全通過,可以完全不通過,可以部分通過,S 函式(Sigmoid)可以達到這樣的目的,下面這樣就是一個簡單的門:

LSTM3-gate

總結一下,我們構造 LSTM 網路,這個網路有能力讓前面的資料傳遞到最後,網路具有長期記憶的能力,同時也有門的控制,及時捨棄那些無用的記憶。

詳細分析

有了這樣的核心思想,再看這個網路就簡單了好多,從左到右第一層是“選擇性忘記”。我們根據前一次的輸出和本次的輸入,通過 Sigmoid 判斷出前一次哪些記憶需要保留和忘記:

LSTM3-focus-f

第二部分又分為了兩個部分,一個部分是“輸入門層”,用 Sigmoid 決定哪些資訊需要進行更新,另一個部分是建立候選值向量,即本次輸入和上次輸出進行初步計算後的中間狀態:

LSTM3-focus-i

經過前面的計算,我們可以更新單元格的狀態了。第一步,前一個的單元格哪些資料需要傳遞,哪些資料需要忘記;第二步,本次的哪些資料需要更新,乘以本次計算的中間狀態可以得到本次的更新資料;再把前兩步的資料相加,就是新的單元格狀態,可以繼續向後傳遞。

LSTM3-focus-C

這一步需要決定我們的輸出:第一步,我們用 Sigmoid 來判斷我們需要輸出的部分;第二步,把上面計算得到的單元格狀態通過 tanh 計算將資料整理到 -1 到 1 的區間內;第三步,把第一步和第二步的資料相乘,就得到了最後的輸出:

LSTM3-focus-o

總結一下我們剛剛做了什麼:我們首先通過本次的輸入和上次的輸出,判斷出上次單元格狀態有哪些資料需要保留或捨棄,再根據本次的輸入進行網路訓練,進一步得到本次訓練的單元格狀態和輸出,並將單元格狀態和本次的輸出繼續往後傳遞。

這裡有一個疑問,為什麼需要捨棄?舉個例子,翻譯一篇文章,一篇文章前一段介紹某一個人的詳細資訊和背景,下一段介紹今天發生的某個故事,兩者的關係是弱耦合的,需要及時捨棄前面對人背景資訊的記憶,才會更好的翻譯下面的故事。

其他一些基於 LSTM 修改版的網路,本質是一樣的,只不過把某些地方打通了,有論文驗證過,一般情況下對訓練的結果影響很小,這裡不展開介紹,大同小異,修內功而不是那些奇奇怪怪的招式:

LSTM3-var-peepholes

LSTM3-var-tied

LSTM3-var-GRU

總結

本文介紹了長短期記憶網路,在大多數情況下,若在某個領域用 RNN 取得了比較好的效果,其很可能就是使用的 LSTM。這是一篇好文,本文圖片來自Understanding-LSTMs,值得一讀。

  • 本文首發自: RAIS

相關文章