lstm(一) 演化之路

weixin_34119545發表於2019-01-08

遞迴神經網路引入了時序的反饋機制,在語音、音樂等時序訊號的分析上有重要的意義。 
Hochreiter(應該是Schmidhuber的弟子)在1991年分析了bptt帶來的梯度爆炸和消失問題,給學習演算法帶來了梯度 
震盪和學習困難等問題; 
Hochreater和Schmidhuber在1997年提出了LSTM的網路結構,引入CEC單元解決bptt的梯度爆炸和消失問題; 
Felix Gers(Schmidhuber是指導人之一)2001年的博士論文進一步改進了lstm的網路結構,增加了forget gate和peephole; 
Alex Graves(Schmidhuber的弟子)2006年提出了lstm的ctc訓練準則

第一步:RNN->基本lstm

參考文獻[1]和[2]

問題

問題一:gradient

BPTT學習演算法存在梯度爆炸和消失問題(gradient blow up or vanish),簡單通過local error flow分析如下: 
對RNN的隱層進行unfolding後,可以得到如下的遞推關係: 

ϑj(t)=fj(netj(t))iwijϑi(t+1)


可以理解為t+1時刻的error通過wij傳遞到t時刻的j節點。 
對於t時刻的節點u,通過bptt可以傳遞到t-q時刻的節點u,不難得到遞推關係如下: 

ϑv(tq)ϑu(t)=⎧⎩⎨fv(netv(t1))wuvfv(netv(tq))nlq1=1ϑlq1(tq+1)ϑu(t)wlq1vq=1q>1


由此可以遞推: 

ϑlq1(tq+1)ϑu(t)=flq1(netlq1(tq+1))lq2=1nϑlq2(tq+2)ϑu(t)wlq2lq1


......


lq=v以及l0=u最後可得: 

ϑv(tq)ϑu(t)=l1=1n...lq1=1nm=1qflm(netlm(tm))wlmlm1


所以後面的聯乘公式對於梯度的傳播影響很大,如果 

|flm(netlm(tm))wlmlm1|>1.0


會有梯度爆炸問題,反之,有梯度消失問題。 
直觀上理解,error向後傳播的時候,每向前傳遞一步都需要乘以對應的係數W,係數W的大小會導致梯度的異常。

 

問題二:conflict

  • input weight conflict 
    假設wji表示輸入層到隱層之間的連線,對於有些輸入希望儘可能通過,也就是wji比較大,但是另外一些無關的輸入可能希望儘可能遮蔽掉,也就是wji儘可能為0。而實際網路中的wji引數是跟輸入無關,對於所有的輸入,它的大小是一致的。由於缺少這種自動調節功能,從而導致學習比較困難。
  • output weight conflict 
    同理,隱層到輸出層之間也存在放行和遮蔽的conflict。

解決

這裡寫圖片描述
1997年Hochreiter和Schmidhuber首先提出了LSTM的網路結構,解決了傳統RNN的上面兩個問題。

問題一的solution

lstm通過引入CEC(constant error carrousel)單元解決了梯度沿時間尺度unfolding帶來的問題。 
首先梯度的遞推關係如下: 

ϑj(t)=fj(netj(t))iwijϑi(t+1)


要想梯度傳播沒有異常的話,最容易的想法是: 

fj(netj(t))wij=1


這裡進一步簡化這個問題,可以如下約定: 

fj(x)=x


wjj=1


wij=0(ij)


相比之前的RNN結構,做了如下改變: 
1. 矩陣Whh簡化為對角矩陣,也就是隻允許節點的自旋,不允許隱層的其他節點連線到本節點 
2. 啟用函式sigmoid替換為了線性函式f(x)=x

 

以上兩點保證了error可以無損由t時刻傳遞到t-1時刻,如上圖中的scj(t)=scj(t1)+gyinj,CEC是lstm的核心部件

問題二的solution

針對問題二,lstm引入了兩個gate:input gate(對應圖中的inj)可以控制某些輸入進入cell(對應圖中的cj)更新原來儲存的資訊,或者遮蔽輸入以保持cell儲存的資訊不變;output gate(對應圖中的outj)以控制cell的資訊對輸出產生多大程度的影響。 
以output gate為例,直觀上可以理解為(個人理解,歡迎討論): 
傳統RNN隱層到輸出層的連線是由引數wkj控制的,由於對所有的隱層輸出,wkj都是一視同仁的,這對於傳統的DNN模型來講是沒有問題,因為DNN模型不會考慮歷史資訊,DNN可以理解為一個簡單的函式,一個輸入有對應的輸出就可以了,沒有必要使用gate加以限制。 
但是對於RNN來講,這種一視同仁則是不合理的,因為它需要儲存歷史資訊,而且該歷史資訊會對當前的輸出產生影響(有些時候是至關重要的影響),比如說”I grew up in France… I speak fluent French”,此時的France對French的影響將會是決定性的。所以lstm引入了gate來改善這種一視同仁的不良現狀,為什麼gate會改善這種不良現狀呢? 
因為此時的輸出結果不再只由wkj來控制了,還會受到output gate開門或者關門的影響,而output gate開門或者關門的控制權是由woutji引數以及輸入x等引數控制的,所以此時的輸出多了一條watchdog,引數增多了控制的將會更精確。

第二步:lstm + forget gate

參考文獻[3]

問題

傳統的lstm存在一個問題:隨著時間序列的增多,lstm網路沒有重置的機制(比如兩句話合成一句話作為輸入的話,希望是在第一句話結束的時候進行reset),從而導致cell state容易發生飽和,進一步會導致cell state的輸出h(趨近於1)的梯度很小(sigmoid函式在x值很大的時候梯度趨向於0),阻礙了error的傳入;另一方面輸出h趨近於1,導致cell的輸出近似等於output gate的輸出,意味著網路喪失了memory的功能。

解決

這裡寫圖片描述
在傳統lstm的基礎之上,引入了forget gate。使用這種結構可以讓網路自動學習什麼時候應該reset。具體做法即為使用yφ替換原來的CEC的常量1.0,定義如下; 

netφj(t)=mwφjmym(t1)


yφj(t)=fφj(netφj(t))


forget gate的引入可以解決需要內部切割(hierarchical decomposition)但又不知道切割位置的序列問題。

 

第三步:lstm+peephole

參考文獻[3]

問題

lstm的gate的輸入包含兩個部分,網路輸入和上一時刻(t-1)網路的輸出。 
此時如果output gate關閉(值接近0)的話,網路的輸出(t時刻)將為0,下一時刻(t+1)網路gate將完全跟網路輸入有關,就會丟失歷史資訊。

解決

這裡寫圖片描述
增加CEC到各個gate之間的連線,使得CEC(const error carrousels)和gate之間存在雙向的關聯,CEC收到當前時刻gate的限制,同時又會影響下一時刻的gate。 
- input gate和forget gate的輸入增加一項s(t1) 
- output gate的輸入增加一項s(t)

peephole使得網路可以記錄更多的時序上的關聯性,有助於提取相關事件準確週期的相關資訊,可以應用於音樂韻律的分析等工作。

第四步:CTC訓練準則

ctc訓練

參考

[1]《Untersuchungen zu dynamischen neuronalen Netzen》 Hochreiter(德文的,人家的碩士論文) 
[2]《Long Short-Term Memory》 Hochreiter, Sepp; Schmidhuber 
[3]《Long Short-Term Memory in Recurrent Neural Networks》 Felix Gers 
[4]《Supervised Sequence Labelling with Recurrent Neural Networks》 Alex Graves 
[5] http://colah.github.io/posts/2015-08-Understanding-LSTMs/

相關文章