RNN-迴圈神經網路和LSTM_01基礎

CopperDong發表於2018-05-27

一、介紹

1、什麼是RNN

  • 傳統的神經網路是層與層之間是全連線的,但是每層之間的神經元是沒有連線的(其實是假設各個資料之間是獨立的
    • 這種結構不善於處理序列化的問題。比如要預測句子中的下一個單詞是什麼,這往往與前面的單詞有很大的關聯,因為句子裡面的單詞並不是獨立的。
  • RNN 的結構說明當前的的輸出與前面的輸出也有關,即隱層之間的節點不再是無連線的,而是有連線的
    • 基本的結構如圖,可以看到有個迴圈的結構,將其展開就是右邊的結構

RNN基本結構

2、運算說明

  • 如上圖,輸入單元(inputs units): {x0,x1,,xt,xt+1,}{x0,x1,⋯⋯,xt,xt+1,⋯⋯},
    • 輸出單元(output units)為:{o0,o1,,ot,ot+1,}{o0,o1,⋯⋯,ot,ot+1,⋯⋯},
    • 隱藏單元(hidden units)輸出集: {s0,s1,,ost,st+1,}{s0,s1,⋯⋯,ost,st+1,⋯⋯}
  • 時間 t 隱層單元的輸出為:st=f(Uxt+Wst1)st=f(Uxt+Wst−1)
    • f就是激勵函式,一般是sigmoid,tanh, relu
    • 計算s0s0時,即第一個的隱藏層狀態,需要用到s1s−1,但是其並不存在,在實現中一般置為0向量
    • (如果將上面的豎著立起來,其實很像傳統的神經網路,哈哈)
  • 時間 t 的輸出為:ot=Softmax(Vst)ot=Softmax(Vst)
    • 可以認為隱藏層狀態stst網路的記憶單元stst包含了前面所有步的隱藏層狀態。而輸出層的輸出otot只與當前步的stst有關。
    • (在實踐中,為了降低網路的複雜度,往往stst只包含前面若干步而不是所有步的隱藏層狀態)
  • RNNs中,每輸入一步,每一層都共享引數U,V,W,(因為是將迴圈的部分展開,天然應該相等)
  • RNNs的關鍵之處在於隱藏層,隱藏層能夠捕捉序列的資訊。

3、應用方面

機器翻譯

(3) 語音識別

二、結構

1、One to One

  • 即一個輸入對應一個輸出,就是上面的圖

    2、Many to One

  • 即多個輸入對應一個輸出,比如情感分析,一段話中很多次,判斷這段話的情感
  • 其中x1,x2,,xtx1,x2,…,xt表示句子中的t個詞,o代表最終輸出的情感標籤
  • 前向計算就是:
    f(x)=Vst=V(Uxt+Wst1)=V(Uxt+W(Uxt1+Wst2))f(x)=Vst=V(Uxt+Wst−1)=V(Uxt+W(Uxt−1+Wst−2))⋯

    Many to one

    3、One to Many

  • 前向計算類似,不再給出
    One to Many

    4、Many to Many

  • 前向計算類似,不再給出
    Many to Many

    5、雙向RNN(Bidirectional RNN)

  • 比如翻譯問題往往需要聯絡上下文內容才能正確的翻譯,我們上面的結構線性傳遞允許“聯絡上文”,但是聯絡下文並沒有,所以就有雙向RNN
  • 前向運算稍微複雜一點,以t時刻為例
    ot=W(os)tst+W(oh)tht=W(os)t(W(ss)t1st1+W(sx)txt1)+W(oh)t(W(hh)tht+1+W(hx)txt)ot=Wt(os)st+Wt(oh)ht=Wt(os)(Wt−1(ss)st−1+Wt(sx)xt−1)+Wt(oh)(Wt(hh)ht+1+Wt(hx)xt)
    RNN

6、深層的RNN

  • 上面的結構都是隻含有一層的state層,根據傳統NN和CNN,深層次的結構有更加號的效果,結構如圖
    深層的RNN

三、Back Propagation Through Time(BPTT)訓練

  • 關於傳統神經網路BP演算法可以檢視這裡神經網路部分的推導

    1、符號等說明

  • 以下圖為例

RNN基本結構

  • 符號說明
    • ϕϕ………………………………………………隱藏層的激勵函式
    • φφ………………………………………………輸出層的變換函式
    • Lt=Lt(ot,yt)Lt=Lt(ot,yt)……………………………模型的損失函式
      • 標籤資料ytyt是一個 one-hot 向量

2、反向傳播過程

  • 接受完序列中所有樣本後再統一計算損失,此時模型的總損失可以表示為(假設輸入序列長度為n):
    L=t=1nLtL=∑t=1nLt

    RNN
  • ot=φ(Vst)=φ(V(Uxt+Wst1))ot=φ(Vst)=φ(V(Uxt+Wst−1))
    • 其中s0=0=(0,0,,0)Ts0=0=(0,0,…,0)T
  • 令:ot=Vst,st=Uxt+Wst1(1)ot∗=Vst,st∗=Uxt+Wst−1…………(1) (就是沒有經過激勵函式和變換函式前)
    • 則:ot=φ(ot)ot=φ(ot∗)
    • st=ϕ(st)st=ϕ(st∗)

(1) 矩陣V的更新

  • 矩陣 V 的更新過程,根據(1)式可得, (和傳統的神經網路一致,根據求導的鏈式法則):
    • Ltot=Ltototot=Ltotφ(ot)∂Lt∂ot∗=∂Lt∂ot∗∂ot∂ot∗=∂Lt∂ot∗φ′(ot∗)
    • LtV=LtVstVstV=Ltot×sTt=(Ltotφ(ot))×sTt∂Lt∂V=∂Lt∂Vst∗∂Vst∂V=∂Lt∂ot∗×stT=(∂Lt∂ot∗φ′(ot∗))×stT
      - 因為L=nt=1LtL=∑t=1nLt,所以對矩陣V的更新對應的導數:
      LV=t=1n(Ltotφ(ot))×sTt∂L∂V=∑t=1n(∂Lt∂ot∗φ′(ot∗))×stT

(2) 矩陣U和W的更新

  • RNN 的 BP 演算法的主要難點在於它 State 之間的通訊
  • 可以採用迴圈的方法來計算各個梯度,t應從n開始降序迴圈至 1
  • 計算時間通道上的區域性梯度(同樣根據鏈式法則)
    Ltst=LtVst×sTtVTtststst=VT×(Ltotφ(ot))∂Lt∂st∗=∂Lt∂Vst×∂stTVtT∂st∗∂st∂st∗=VT×(∂Lt∂ot∗φ‘(ot∗))

Ltsk1=sksk1×Ltsk=WT×(Ltskϕ(sk1)),(k=1,,t)(2)∂Lt∂sk−1∗=∂sk∗∂sk−1∗×∂Lt∂sk∗=WT×(∂Lt∂sk∗∗ϕ′(sk−1∗)),(k=1,……,t)………(2)

  • 利用區域性梯度計算UW的梯度
    • 這裡累加是因為權值是共享的,所以往前推算一直用的是一樣的權值
      LtU+=k=1tLtsk×skU=k=1tLtsk×xTt∂Lt∂U+=∑k=1t∂Lt∂sk∗×∂sk∗∂U=∑k=1t∂Lt∂sk∗×xtT

      LtW+=k=1tLtsk×skW=k=1tLtsk×sTt1..(3)∂Lt∂W+=∑k=1t∂Lt∂sk∗×∂sk∗∂W=∑k=1t∂Lt∂sk∗×st−1T………………..(3)

3、訓練問題

  •  公式(2)和(3) 中可以看出,時間維度上的權重W更新需要計算ϕ(sk)ϕ‘(sk∗),即經過激勵函式的導數
  • 如果時間維度上很長,則這個梯度是累積的,所以造成梯度消失或爆炸
    • 可以想象將結構圖豎起來,就是一個深層的神經網路,所以容易出現梯度問題
    • 關於梯度消失的問題可以檢視我這裡一遍部落格
  • RNN 主要的作用就是能夠記住之前的資訊,但是梯度消失的問題又告訴我們不能記住太久之前的資訊,改進的思路有兩點
    • 一是使用一些trick,比如合適的激勵函式,初始化,BN等等
    • 二是改進state的傳遞方式,比如就是下面提及的LSTM
      • 關於為何 LSTMs 能夠解決梯度消失,直觀上來說就是上方時間通道是簡單的線性組合

四、Long Short-Term Memory(LSTM,長短時記憶網路)

1、介紹

  • LSTM 是一般 RNN 的升級,因為一些序列問題,我們可能需要忘記一些東西, LSTM 和普通 RNN 相比, 多出了三個控制器. (輸入控制, 輸出控制, 忘記控制)
  • LSTM裡,這個叫做cell(其實就是前面的state,只是這裡更加複雜了), 可以看作一個黑盒,這個cell結合前面cell的輸出ht1ht−1和當前的輸入xtxt來決定是否記憶下來,該網路結構在對長序列依賴問題中非常有效

2、結構

  • 一個經典的cell結構如下圖
    • ϕ1ϕ1sigmoid函式,ϕ2ϕ2 是tanh函式
    • *表示 element wise 乘法(就是點乘),使用X表示矩陣乘法
  • LSTMs 的 cell 的時間通道有兩條
    • 上方的時間通道(h(old)h(new)h(old)→h(new))僅包含了兩個代數運算,這意味著它資訊傳遞的方式會更為直接
      h(new)=h(old)r1+r2h(new)=h(old)∗r1+r2
    • 位於下方的時間通道(s(old)s(new)s(old)→s(new))則運用了大量的層結構,在 LSTMs 中,我們通常稱這些層結構為門(Gates

LSTM cell結構

3、運算說明

  • Sigmoid 函式取值區間為 0-1,那麼當 Sigmoid 對應的層結構輸出 0 時,就對應著遺忘這個過程;當輸出 1時,自然就對應著接受這個過程。
    • 事實上這也是 Sigmoid 層叫門的原因——它能決定“放哪些資料進來”和決定“不讓哪些資料通過”
  • 最左邊的Sigmoid gate 叫做遺忘門, 控制著時間通道資訊的遺忘程度
    • 前向計算: r1=ϕ1(W1×x)r1=ϕ1(W1×x∗)
      • 其中 x=Δ[x,s(old)]x∗=Δ[x,s(old)],表示當前輸入樣本和下方時間通道s(old)s(old)連線(concat)起來
  • 第二個 Sigmoid Gate 通常被稱為輸入門(Input Gate), 控制著當前輸入和下方通道資訊對上方通道資訊的影響
    • 前向運算為:g1=ϕ1(W2×x)g1=ϕ1(W2×x∗),
  • 第三個 Tanh Gate 則允許網路結構駁回歷史資訊, 因為tanh的值域是(-1,1)
    • 前向運算為:g2=ϕ2(W3×x)g2=ϕ2(W3×x∗)
    • r2=g1g2r2=g1∗g2
  • 第四個 Sigmoid Gate 通常被稱為輸出門(Output Gate),它為輸出和傳向下一個 cell 的下方通道資訊作出了貢獻。
    • 對應的前向傳導演算法為:g3=ϕ1(W4×x)g3=ϕ1(W4×x∗)
  • 最終cell的輸出為:o=s(new)=ϕ2(h(new))g3o=s(new)=ϕ2(h(new))∗g3
  • 每個 Gate 對應的權值矩陣是不同的(W1W4W1∼W4),切勿以為它們會共享權值

Reference

相關文章