三次簡化一張圖: 一招理解LSTM/GRU門控機制

AI科技大本營發表於2019-02-26

作者 | 張皓


引言

RNN是深度學習中用於處理時序資料的關鍵技術, 目前已在自然語言處理, 語音識別, 視訊識別等領域取得重要突破, 然而梯度消失現象制約著RNN的實際應用。LSTM和GRU是兩種目前廣為使用的RNN變體,它們通過門控機制很大程度上緩解了RNN的梯度消失問題,但是它們的內部結構看上去十分複雜,使得初學者很難理解其中的原理所在。本文介紹”三次簡化一張圖”的方法,對LSTM和GRU的內部結構進行分析。該方法非常通用,適用於所有門控機制的原理分析。

預備知識: RNN

RNN (recurrent neural networks, 注意不是recursiveneural networks)提供了一種處理時序資料的方案。和n-gram只能根據前n-1個詞來預測當前詞不同, RNN理論上可以根據之前所有的詞預測當前詞。在每個時刻, 隱層的輸出ht依賴於當前詞輸入xt和前一時刻的隱層狀態ht-1:

三次簡化一張圖: 一招理解LSTM/GRU門控機制

其中:=表示”定義為”, sigm代表sigmoid函式sigm(z):=1/(1+exp(-z)), Wxh和Whh是可學習的引數。結構見下圖:

三次簡化一張圖: 一招理解LSTM/GRU門控機制

圖中左邊是輸入,右邊是輸出。xt是當前詞,ht-1記錄了上文的資訊。xt和ht-1在分別乘以Wxh和Whh之後相加,再經過tanh非線性變換,最終得到ht。

在反向傳播時,我們需要將RNN沿時間維度展開,隱層梯度在沿時間維度反向傳播時需要反覆乘以引數。因此, 儘管理論上RNN可以捕獲長距離依賴, 但實際應用中,根據譜半徑(spectralradius)的不同,RNN將會面臨兩個挑戰:梯度爆炸(gradient explosion)和梯度消失(vanishing gradient)。梯度爆炸會影響訓練的收斂,甚至導致網路不收斂;而梯度消失會使網路學習長距離依賴的難度增加。這兩者相比, 梯度爆炸相對比較好處理,可以用梯度裁剪(gradientclipping)來解決,而如何緩解梯度消失是RNN及幾乎其他所有深度學習方法研究的關鍵所在。

LSTM

LSTM通過設計精巧的網路結構來緩解梯度消失問題,其數學上的形式化表示如下:

三次簡化一張圖: 一招理解LSTM/GRU門控機制

其中代表逐元素相乘。這個公式看起來似乎十分複雜,為了更好的理解LSTM的機制, 許多人用圖來描述LSTM的計算過程, 比如下面的幾張圖:

三次簡化一張圖: 一招理解LSTM/GRU門控機制

似乎看完了這些圖之後,你對LSTM的理解還是一頭霧水? 這是因為這些圖想把LSTM的所有細節一次性都展示出來,但是突然暴露這麼多的細節會使你眼花繚亂,從而無處下手。

因此,本文提出的方法旨在簡化門控機制中不重要的部分,從而更關注在LSTM的核心思想。整個過程是“三次簡化一張圖”,具體流程如下:

  • 第一次簡化: 忽略門控單元i,f,o的來源。3個門控單元的計算方法完全相同, 都是由輸入經過線性對映得到的, 區別只是計算的引數不同。這樣做的目的是為了梯度反向傳導時能對門控單元進行更新。這不是LSTM的核心思想, 在進行理解時,我們可以假定各門控單元是給定的。

  • 第二次簡化: 考慮一維情況。LSTM中對各維是獨立進行門控的,所以為了理解方便,我們只需要考慮一維情況。

  • 第三次簡化: 各門控單元0/1輸出。 門控單元輸出是[0,1]實數區間的原因是階躍啟用函式無法反向傳播進行優化, 所以各門控單元使用sigmoid啟用函式去近似階躍函式。 因此, 為了理解方便, 我們只需要考慮理想情況, 即各門控單元是{0,1}二值輸出的,即門控單元扮演了電路中”開關”的角色, 用於控制資訊傳輸的通斷。

  • 一張圖: 將三次簡化的結果用”電路圖”表述出來,左邊是輸入,右邊是輸出。另外需要特別注意的是LSTM中的c實質上起到了RNN中h的作用, 這點在其他文獻資料中不常被提到。最終結果如下:

和RNN相同的是,網路接受兩個輸入,得到一個輸出。不同之處在於, LSTM中通過3個門控單元來對記憶單元c的資訊進行互動。

根據這張圖,我們可以對LSTM中各單元作用進行分析:

  • 輸入門it: it控制當前詞xt的資訊融入記憶單元ct。在理解一句話時,當前詞xt可能對整句話的意思很重要,也可能並不重要。輸入門的目的就是判斷當前詞xt對全域性的重要性。當it開關開啟的時候,網路將不考慮當前輸入xt。

  • 遺忘門ft: ft控制上一時刻記憶單元ct-1的資訊融入記憶單元ct。在理解一句話時,當前詞xt可能繼續延續上文的意思繼續描述,也可能從當前詞xt開始描述新的內容,與上文無關。和輸入門it相反, ft不對當前詞xt的重要性作判斷, 而判斷的是上一時刻的記憶單元ct-1對計算當前記憶單元ct的重要性。當ft開關開啟的時候,網路將不考慮上一時刻的記憶單元ct-1。

  • 輸出門ot: 輸出門的目的是從記憶單元ct產生隱層單元ht。並不是ct中的全部資訊都和隱層單元ht有關,ct可能包含了很多對ht無用的資訊,因此, ot的作用就是判斷ct中哪些部分是對ht有用的,哪些部分是無用的。

  • 記憶單元ct:ct綜合了當前詞xt和前一時刻記憶單元ct-1的資訊。這和ResNet中的殘差逼近思想十分相似,通過從ct-1到ct的”短路連線”, 梯度得已有效地反向傳播。 當ft處於閉合狀態時, ct的梯度可以直接沿著最下面這條短路線傳遞到ct-1,不受引數W的影響,這是LSTM能有效地緩解梯度消失現象的關鍵所在。

GRU

GRU是另一種十分主流的RNN衍生物。RNN和LSTM都是在設計網路結構用於緩解梯度消失問題, 只不過是網路結構有所不同。GRU在數學上的形式化表示如下:

三次簡化一張圖: 一招理解LSTM/GRU門控機制

為了理解GRU的設計思想,我們再一次運用“三次簡化一張圖”的方法來進行分析:

  • 第一次簡化: 忽略門控單元z, r的來源。

  • 第二次簡化: 考慮一維情況。

  • 第三次簡化: 各門控單元0/1輸出。這裡和LSTM略有不同的地方在於,GRU需要引入一個”單刀雙擲開關”。

  • 一張圖: 把三次簡化的結果用”電路圖”表述出來,左輸入,右輸出:

三次簡化一張圖: 一招理解LSTM/GRU門控機制

與LSTM相比,GRU將輸入門it和遺忘門ft融合成單一的更新門zt,並且融合了記憶單元ct和隱層單元ht,所以結構上比LSTM更簡單一些。

根據這張圖,我們可以對GRU的各單元作用進行分析:

  • 重置門rt:rt用於控制前一時刻隱層單元ht-1對當前詞xt的影響。如果ht-1對xt不重要,即從當前詞xt開始表述了新的意思,與上文無關, 那麼rt開關可以開啟, 使得ht-1對xt不產生影響。

  • 更新門zt:zt用於決定是否忽略當前詞xt。類似於LSTM中的輸入門it, zt可以判斷當前詞xt對整體意思的表達是否重要。當zt開關接通下面的支路時,我們將忽略當前詞xt,同時構成了從ht-1到ht的”短路連線”,這梯度得已有效地反向傳播。和LSTM相同,這種短路機制有效地緩解了梯度消失現象, 這個機制於highwaynetworks十分相似。

小結

儘管RNN, LSTM,和GRU的網路結構差別很大,但是他們的基本計算單元是一致的,都是對xt和ht-1做一個線性對映加tanh啟用函式,見三個圖的紅色框部分。他們的區別在於如何設計額外的門控機制控制梯度資訊傳播用以緩解梯度消失現象。LSTM用了3個門,GRU用了2個,那能不能再少呢? MGU (minimal gate unit)嘗試對這個問題做出回答, 它只有一個門控單元。

最後留個小練習, 參考LSTM和GRU的例子,你能不能用“三次簡化一張圖”的方法來分析一下MGU呢?

參考文獻

1. Bengio, Yoshua, PatriceSimard, and Paolo Frasconi。 “Learning long-term dependencies with gradient descent isdifficult。” IEEE transactions on neural networks 5。2 (1994):157-166。

2. Cho, Kyunghyun, et al。”Learning phrase representations using RNN encoder-decoder for statisticalmachine translation。” arXiv preprint arXiv:1406。1078 (2014)。

3. Chung, Junyoung, et al。”Empirical evaluation of gated recurrent neural networks on sequencemodeling。” arXiv preprint arXiv:1412。3555 (2014)。

4. Gers, Felix。 “Longshort-term memory in recurrent neural networks。” UnpublishedPhD dissertation, Ecole Polytechnique Fédérale de Lausanne, Lausanne, Switzerland(2001)。

5. Goodfellow, Ian, YoshuaBengio, and Aaron Courville。 Deep learning。 MIT press, 2016。

6. Graves, Alex。 Supervisedsequence labelling with recurrent neural networks。 Vol。 385。 Heidelberg:Springer, 2012。

7. Greff, Klaus, et al。 “LSTM:A search space odyssey。” IEEE transactions on neural networks and learning systems(2016)。

8. He, Kaiming, et al。 “Deepresidual learning for image recognition。” Proceedingsof the IEEE conference on computer vision and pattern recognition。 2016。

9. He, Kaiming, et al。”Identity mappings in deep residual networks。” EuropeanConference on Computer Vision。 Springer International Publishing, 2016。

10. Hochreiter, Sepp, and JürgenSchmidhuber。 “Long short-term memory。” Neuralcomputation 9。8 (1997): 1735-1780。

11. Jozefowicz, Rafal, WojciechZaremba, and Ilya Sutskever。 “An empirical exploration of recurrent network architectures。” Proceedingsof the 32nd International Conference on Machine Learning (ICML-15)。 2015。

12. Li, Fei-Fei, JustinJohnson, and Serena Yeung。 CS231n: Convolutional Neural Networks for Visual Recognition。 Stanford。 2017。

13. Lipton, Zachary C。, JohnBerkowitz, and Charles Elkan。 “A critical review of recurrent neural networks for sequencelearning。” arXiv preprint arXiv:1506。00019 (2015)。

14. Manning, Chris andRichard Socher。 CS224n: Natural Language Processing with Deep Learning。 Stanford。 2017。

15. Pascanu, Razvan, Tomas Mikolov, and YoshuaBengio。 “On the difficulty of training recurrent neural networks。”International Conference on Machine Learning。 2013。

16. Srivastava, RupeshKumar, Klaus Greff, and Jürgen Schmidhuber。 “Highwaynetworks。” arXiv preprint arXiv:1505。00387 (2015)。

17. Williams, D。 R。 G。 H。 R。, andGeoffrey Hinton。 “Learning representations by back-propagating errors。”Nature 323。6088 (1986): 533-538。

18. Zhou, Guo-Bing, et al。”Minimal gated unit for recurrent neural networks。”International Journal of Automation and Computing 13。3 (2016):226-234。

本文是投稿文章,作者:張皓

github地址:https://github.com/HaoMood/

注:AI科技大本營現已開通投稿通道,投稿請加編輯微信1092722531

相關文章