YJango的迴圈神經網路——實現LSTM

超智慧體發表於2018-08-13

介紹

描述最常用的RNN實現方式:Long-Short Term Memory(LSTM)

梯度消失和梯度爆炸

網路回憶:《迴圈神經網路——介紹》中提到迴圈神經網路用相同的方式處理每個時刻的資料。

  • 動態圖:
YJango的迴圈神經網路——實現LSTM
  • 數學公式h_t= \phi(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + {b})

設計目的:我們希望迴圈神經網路可以將過去時刻發生的狀態資訊傳遞給當前時刻的計算中。

實際問題:但普通的RNN結構卻難以傳遞相隔較遠的資訊。

  • 考慮:若只看上圖藍色箭頭線的、隱藏狀態的傳遞過程,不考慮非線性部分,那麼就會得到一個簡化的式子(1):
    • (1) h_t= W_{hh} \cdot h_{t-1}
    • 如果將起始時刻的隱藏狀態資訊h_0向第t時刻傳遞,會得到式子(2)
    • (2) h_t= (W_{hh})^t \cdot h_{0}
    • W_{hh}會被乘以多次,若允許矩陣W_{hh}進行特徵分解
    • (3) h_t= (W_{hh})^t \cdot h_{0}
    • 式子(2)會變成(4)
    • (4) h_t= Q \cdot \Lambda ^t \cdot Q^T \cdot h_{0}

當特徵值小於1時,不斷相乘的結果是特徵值的t次方向 0 衰減; 當特徵值大於1時,不斷相乘的結果是特徵值的t次方向 \infty擴增。 這時想要傳遞的h_0中的資訊會被掩蓋掉,無法傳遞到h_t

YJango的迴圈神經網路——實現LSTM
  • 類比:設想y=a^t*x,如果a等於0.1,x在被不斷乘以0.1一百次後會變成多小?如果a等於5,x在被不斷乘以5一百次後會變得多大?若想要x所包含的資訊既不消失,又不爆炸,就需要儘可能的將a的值保持在1。
  • 注:更多內容請參閱Deep Learning by Ian Goodfellow中第十章

Long Short Term Memory (LSTM)

上面的現象可能並不意味著無法學習,但是即便可以,也會非常非常的慢。為了有效的利用梯度下降法學習,我們希望使不斷相乘的梯度的積(the product of derivatives)保持在接近1的數值。

一種實現方式是建立線性自連線單元(linear self-connections)和在自連線部分數值接近1的權重,叫做leaky units。但Leaky units的線性自連線權重是手動設定或設為引數,而目前最有效的方式gated RNNs是通過gates的調控,允許線性自連線的權重在每一步都可以自我變化調節。LSTM就是gated RNNs中的一個實現。

LSTM的初步理解

LSTM(或者其他gated RNNs)是在標準RNN (h_t= \phi(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + {b}))的基礎上裝備了若干個控制數級(magnitude)的gates。可以理解成神經網路(RNN整體)中加入其他神經網路(gates),而這些gates只是控制數級,控制資訊的流動量

數學公式:這裡貼出基本LSTM的數學公式,看一眼就好,僅僅是為了讓大家先留一個印象,不需要記住,不需要理解。

YJango的迴圈神經網路——實現LSTM

儘管式子不算複雜,卻包含很多知識,接下來就是逐步分析這些式子以及背後的道理。 比如\odot的意義和使用原因,sigmoid的使用原因。

門(gate)的理解

理解Gated RNNs的第一步就是明白gate到底起到什麼作用。

  • 物理意義:gate本身可看成是十分有物理意義的一個神經網路
    • 輸入:gate的輸入是控制依據;
    • 輸出:gate的輸出是值域為(0,1)的數值,表示該如何調節其他資料的數級的控制方式。
  • 使用方式:gate所產生的輸出會用於控制其他資料的數級,相當於過濾器的作用。
    • 類比圖:可以把資訊想象成水流,而gate就是控制多少水流可以流過。
YJango的迴圈神經網路——實現LSTM
  • 例如:當用gate來控制向量\left[\begin{matrix}20 & 5& 7 & 8 \\\end{matrix}\right]時,
  • 若gate的輸出為\left[\begin{matrix}0.1 & 0.2& 0.9 & 0.5 \\\end{matrix}\right]時,原來的向量就會被對應元素相乘(element-wise)後變成:

\left[\begin{matrix}20 & 5& 7 & 8 \\\end{matrix}\right]\odot \left[\begin{matrix}0.1 & 0.2& 0.9 & 0.5 \\\end{matrix}\right]=\left[\begin{matrix}20*0.1 & 5*0.2& 7*0.9 & 8*0.5 \\\end{matrix}\right]=\left[\begin{matrix}2 & 1& 6.3 & 4 \\\end{matrix}\right]

  • 若gate的輸出為\left[\begin{matrix}0.5 & 0.5& 0.5 & 0.5 \\\end{matrix}\right]時,原來的向量就會被對應元素相乘(element-wise)後變成:

\left[\begin{matrix}20 & 5& 7 & 8 \\\end{matrix}\right]\odot \left[\begin{matrix}0.5 & 0.5& 0.5 & 0.5 \\\end{matrix}\right]=\left[\begin{matrix}10 & 2.5& 3.5 & 4 \\\end{matrix}\right]

  • 控制依據:明白了gate的輸出後,剩下要確定以什麼資訊為控制依據,也就是什麼是gate的輸入。
  • 例如:即便是LSTM也有很多個變種。一個變種方式是調控門的輸入。例如下面兩種gate:g= sigmoid(W_{xg} \cdot x_t + W_{hg} \cdot h_{t-1} + {b})
  • 這種gate的輸入有當前的輸入x_t和上一時刻的隱藏狀態h_{t-1}, 表示gate是將這兩個資訊流作為控制依據而產生輸出的。
  • g= sigmoid(W_{xg} \cdot x_t + W_{hg} \cdot h_{t-1} +W_{cg} \cdot c_{t-1}+ {b})
  • 這種gate的輸入有當前的輸入x_t 和上一時刻的隱藏狀態h_{t-1},以及上一時刻的cell狀態c_{t-1}, 表示gate是將這三個資訊流作為控制依據而產生輸出的。這種方式的LSTM叫做peephole connections。

LSTM的再次理解

明白了gate之後再回過頭來看LSTM的數學公式

數學公式

YJango的迴圈神經網路——實現LSTM
  • gates:先將前半部分的三個式子i_t,f_t,o_t統一理解。在LSTM中,網路首先構建了3個gates來控制資訊的流通量。
  • 注: 雖然gates的式子構成方式一樣,但是注意3個gates式子Wb的下角標並不相同。它們有各自的物理意義,在網路學習過程中會產生不同的權重
  • 有了這3個gates後,接下來要考慮的就是如何用它們裝備在普通的RNN上來控制資訊流,而根據它們所用於控制資訊流通的地點不同,它們又被分為:
    • 輸入門i_t:控制有多少資訊可以流入memory cell(第四個式子c_t)。
    • 遺忘門f_t:控制有多少上一時刻的memory cell中的資訊可以累積到當前時刻的memory cell中。
    • 輸出門o_t:控制有多少當前時刻的memory cell中的資訊可以流入當前隱藏狀態h_t中。
    • 注:gates並不提供額外資訊,gates只是起到限制資訊的量的作用。因為gates起到的是過濾器作用,所以所用的啟用函式是sigmoid而不是tanh。
  • 資訊流:資訊流的來源只有三處,當前的輸入x_t,上一時刻的隱藏狀態h_{t-1},上一時刻的cell狀態c_{t-1},其中c_{t-1}是額外製造出來、可線性自連線的單元(請回想起leaky units)。真正的資訊流來源可以說只有當前的輸入x_t,上一時刻的隱藏狀態h_{t-1}兩處。三個gates的控制依據,以及資料的更新都是來源於這兩處。
  • 分析了gates和資訊流後,再分析剩下的兩個等式,來看LSTM是如何累積歷史資訊和計算隱藏狀態h的。
  • 歷史資訊累積:
    • 式子:c _t = f_t \odot c_{t - 1} + i_t \odot tanh(W_{xc} x_t + W_{hc}h_{t-1} + b_c)
    • 其中new=tanh(W_{xc} x_t + W_{hc}h_{t-1} + b_c)是本次要累積的資訊來源。
    • 改寫:c _t = f_t \odot c_{t - 1} + i_t \odot new

所以歷史資訊的累積是並不是靠隱藏狀態h自身,而是依靠memory cell這個自連線來累積。 在累積時,靠遺忘門來限制上一時刻的memory cell的資訊,並靠輸入門來限制新資訊。並且真的達到了leaky units的思想,memory cell的自連線是線性的累積。

  • 當前隱藏狀態的計算:如此大費周章的最終任然是同普通RNN一樣要計算當前隱藏狀態。
    • 式子:h_t = o_t \odot tanh(c_t)
    • 當前隱藏狀態h_t是從c_t計算得來的,因為c_t是以線性的方式自我更新的,所以先將其加入帶有非線性功能的tanh(c_t)。 隨後再靠輸出門o_t的過濾來得到當前隱藏狀態h_t

普通RNN與LSTM的比較

下面為了加深理解迴圈神經網路的核心,再來和YJango一起比較一下普通RNN和LSTM的區別。

  • 比較公式:最大的區別是多了三個神經網路(gates)來控制資料的流通。
    • 普通RNN:h_t= tanh(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + {b})
    • LSTM:h _t = o_t \odot tanh(f_t \odot c_{t - 1} + i_t \odot tanh(W_{xc} x_t + W_{hc}h_{t-1} + b_c))
    • 比較:二者的資訊來源都是tanh(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + {b}) 不同的是LSTM靠3個gates將資訊的積累建立線上性自連線的memory cell之上,並靠其作為中間物來計算當前h_t
  • 示圖比較:圖片來自Understanding LSTM,強烈建議一併閱讀。
    • 普通RNN:
YJango的迴圈神經網路——實現LSTM
  • LSTM:加號圓圈表示線性相加,乘號圓圈表示用gate來過濾資訊。
YJango的迴圈神經網路——實現LSTM
  • 比較:新資訊從黃色的tanh處,線性累積到memory cell之中後,又從紅色的tanh處加入非線性並返回到了隱藏狀態h_t的計算中。
LSTM靠3個gates將資訊的積累建立線上性自連線的權重接近1的memory cell之上,並靠其作為中間物來計算當前h_t

LSTM的類比

對於用LSTM來實現RNN的記憶,可以類比我們所用的手機(僅僅是為了方便記憶,並非一一對應)。

YJango的迴圈神經網路——實現LSTM

普通RNN好比是手機螢幕,而LSTM-RNN好比是手機膜。

大量非線性累積歷史資訊會造成梯度消失(梯度爆炸)好比是不斷使用後容易使螢幕刮花。

而LSTM將資訊的積累建立線上性自連線的memory cell之上,並靠其作為中間物來計算當前h_t好比是用手機螢幕膜作為中間物來觀察手機螢幕。

輸入門、遺忘門、輸出門的過濾作用好比是手機螢幕膜的反射率、吸收率、透射率三種性質。

Gated RNNs的變種

需要再次明確的是,神經網路之所以被稱之為網路是因為它可以非常自由的建立合理的連線。而上面所介紹的LSTM也只是最基本的LSTM。只要遵守幾個關鍵點,讀者可以根據需求設計自己的Gated RNNs,而至於在不同任務上的效果需要通過實驗去驗證。下面就簡單介紹YJango所理解的幾個Gated RNNs的變種的設計方向。

  • 資訊流:標準的RNN的資訊流有兩處:input輸入和hidden state隱藏狀態。

但往往資訊流並非只有兩處,即便是有兩處,也可以拆分成多處,並通過明確多處資訊流之間的結構關係來加入先驗知識,減少訓練所需資料量,從而提高網路效果。

例如:Tree-LSTM在具有此種結構的自然語言處理任務中的應用。

YJango的迴圈神經網路——實現LSTM
  • gates的控制方式:與LSTM一樣有名的是Gated Recurrent Unit (GRU),而GRU使用gate的方式就與LSTM的不同,GRU只用了兩個gates,將LSTM中的輸入門和遺忘門合併成了更新門。並且並不把線性自更新建立在額外的memory cell上,而是直接線性累積建立在隱藏狀態上,並靠gates來調控。
YJango的迴圈神經網路——實現LSTM
  • gates的控制依據:上文所介紹的LSTM中的三個gates所使用的控制依據都是W x_t + Wh_{t-1},但是可以通過與memory cell的連線來增加控制依據或者刪除某個gate的W x_tWh_{t-1}來縮減控制依據。比如去掉上圖中z_t=sigmoid(W_z\cdot [h_{t-1},x_t])中的h_{t-1}從而變成z_t=sigmoid(W_z\cdot h_{t-1})

相關文章