LSTM模型與前向反向傳播演算法

劉建平Pinard發表於2017-03-08

    在迴圈神經網路(RNN)模型與前向反向傳播演算法中,我們總結了對RNN模型做了總結。由於RNN也有梯度消失的問題,因此很難處理長序列的資料,大牛們對RNN做了改進,得到了RNN的特例LSTM(Long Short-Term Memory),它可以避免常規RNN的梯度消失,因此在工業界得到了廣泛的應用。下面我們就對LSTM模型做一個總結。

1. 從RNN到LSTM

    在RNN模型裡,我們講到了RNN具有如下的結構,每個序列索引位置t都有一個隱藏狀態$h^{(t)}$。

    如果我們略去每層都有的$o^{(t)}, L^{(t)}, y^{(t)}$,則RNN的模型可以簡化成如下圖的形式:

    圖中可以很清晰看出在隱藏狀態$h^{(t)}$由$x^{(t)}$和$h^{(t-1)}$得到。得到$h^{(t)}$後一方面用於當前層的模型損失計算,另一方面用於計算下一層的$h^{(t+1)}$。

    由於RNN梯度消失的問題,大牛們對於序列索引位置t的隱藏結構做了改進,可以說通過一些技巧讓隱藏結構複雜了起來,來避免梯度消失的問題,這樣的特殊RNN就是我們的LSTM。由於LSTM有很多的變種,這裡我們以最常見的LSTM為例講述。LSTM的結構如下圖:

    可以看到LSTM的結構要比RNN的複雜的多,真佩服牛人們怎麼想出來這樣的結構,然後這樣居然就可以解決RNN梯度消失的問題?由於LSTM怎麼可以解決梯度消失是一個比較難講的問題,我也不是很熟悉,這裡就不多說,重點回到LSTM的模型本身。

 2. LSTM模型結構剖析

    上面我們給出了LSTM的模型結構,下面我們就一點點的剖析LSTM模型在每個序列索引位置t時刻的內部結構。

    從上圖中可以看出,在每個序列索引位置t時刻向前傳播的除了和RNN一樣的隱藏狀態$h^{(t)}$,還多了另一個隱藏狀態,如圖中上面的長橫線。這個隱藏狀態我們一般稱為細胞狀態(Cell State),記為$C^{(t)}$。如下圖所示:

    除了細胞狀態,LSTM圖中還有了很多奇怪的結構,這些結構一般稱之為門控結構(Gate)。LSTM在在每個序列索引位置t的門一般包括遺忘門,輸入門和輸出門三種。下面我們就來研究上圖中LSTM的遺忘門,輸入門和輸出門以及細胞狀態。

 2.1 LSTM之遺忘門

    遺忘門(forget gate)顧名思義,是控制是否遺忘的,在LSTM中即以一定的概率控制是否遺忘上一層的隱藏細胞狀態。遺忘門子結構如下圖所示:

    圖中輸入的有上一序列的隱藏狀態$h^{(t-1)}$和本序列資料$x^{(t)}$,通過一個啟用函式,一般是sigmoid,得到遺忘門的輸出$f^{(t)}$。由於sigmoid的輸出$f^{(t)}$在[0,1]之間,因此這裡的輸出f^{(t)}代表了遺忘上一層隱藏細胞狀態的概率。用數學表示式即為:$$f^{(t)} = \sigma(W_fh^{(t-1)} + U_fx^{(t)} + b_f)$$

    其中$W_f, U_f, b_f$為線性關係的係數和偏倚,和RNN中的類似。$\sigma$為sigmoid啟用函式。

2.2 LSTM之輸入門

    輸入門(input gate)負責處理當前序列位置的輸入,它的子結構如下圖:

    從圖中可以看到輸入門由兩部分組成,第一部分使用了sigmoid啟用函式,輸出為$i^{(t)}$,第二部分使用了tanh啟用函式,輸出為$a^{(t)}$, 兩者的結果後面會相乘再去更新細胞狀態。用數學表示式即為:$$i^{(t)} = \sigma(W_ih^{(t-1)} + U_ix^{(t)} + b_i)$$$$a^{(t)} =tanh(W_ah^{(t-1)} + U_ax^{(t)} + b_a)$$

    其中$W_i, U_i, b_i, W_a, U_a, b_a,$為線性關係的係數和偏倚,和RNN中的類似。$\sigma$為sigmoid啟用函式。

2.3 LSTM之細胞狀態更新

    在研究LSTM輸出門之前,我們要先看看LSTM之細胞狀態。前面的遺忘門和輸入門的結果都會作用於細胞狀態$C^{(t)}$。我們來看看從細胞狀態$C^{(t-1)}$如何得到$C^{(t)}$。如下圖所示:

    細胞狀態$C^{(t)}$由兩部分組成,第一部分是$C^{(t-1)}$和遺忘門輸出$f^{(t)}$的乘積,第二部分是輸入門的$i^{(t)}$和$a^{(t)}$的乘積,即:$$C^{(t)} = C^{(t-1)} \odot f^{(t)} + i^{(t)} \odot a^{(t)}$$

    其中,$\odot$為Hadamard積,在DNN中也用到過。

2.4 LSTM之輸出門

    有了新的隱藏細胞狀態$C^{(t)}$,我們就可以來看輸出門了,子結構如下:

    從圖中可以看出,隱藏狀態$h^{(t)}$的更新由兩部分組成,第一部分是$o^{(t)}$, 它由上一序列的隱藏狀態$h^{(t-1)}$和本序列資料$x^{(t)}$,以及啟用函式sigmoid得到,第二部分由隱藏狀態$C^{(t)}$和tanh啟用函式組成, 即:$$o^{(t)} = \sigma(W_oh^{(t-1)} + U_ox^{(t)} + b_o)$$$$h^{(t)} = o^{(t)} \odot tanh(C^{(t)})$$

    通過本節的剖析,相信大家對於LSTM的模型結構已經有了解了。當然,有些LSTM的結構和上面的LSTM圖稍有不同,但是原理是完全一樣的。

3. LSTM前向傳播演算法

    現在我們來總結下LSTM前向傳播演算法。LSTM模型有兩個隱藏狀態$h^{(t)}, C^{(t)}$,模型引數幾乎是RNN的4倍,因為現在多了$W_f, U_f, b_f, W_a, U_a, b_a, W_i, U_i, b_i, W_o, U_o, b_o$這些引數。

    前向傳播過程在每個序列索引位置的過程為:

    1)更新遺忘門輸出:$$f^{(t)} = \sigma(W_fh^{(t-1)} + U_fx^{(t)} + b_f)$$

    2)更新輸入門兩部分輸出:$$i^{(t)} = \sigma(W_ih^{(t-1)} + U_ix^{(t)} + b_i)$$$$a^{(t)} = tanh(W_ah^{(t-1)} + U_ax^{(t)} + b_a)$$

    3)更新細胞狀態:$$C^{(t)} = C^{(t-1)} \odot f^{(t)} + i^{(t)} \odot a^{(t)}$$

    4)更新輸出門輸出:$$o^{(t)} = \sigma(W_oh^{(t-1)} + U_ox^{(t)} + b_o)$$$$h^{(t)} = o^{(t)} \odot tanh(C^{(t)})$$

    5)更新當前序列索引預測輸出:$$\hat{y}^{(t)} = \sigma(Vh^{(t)} + c)$$

4.  LSTM反向傳播演算法推導關鍵點

    有了LSTM前向傳播演算法,推導反向傳播演算法就很容易了, 思路和RNN的反向傳播演算法思路一致,也是通過梯度下降法迭代更新我們所有的引數,關鍵點在於計算所有引數基於損失函式的偏導數。

    在RNN中,為了反向傳播誤差,我們通過隱藏狀態$h^{(t)}$的梯度$\delta^{(t)}$一步步向前傳播。在LSTM這裡也類似。只不過我們這裡有兩個隱藏狀態$h^{(t)}$和$C^{(t)}$。這裡我們定義兩個$\delta$,即:$$\delta_h^{(t)} = \frac{\partial L}{\partial h^{(t)}}$$$$\delta_C^{(t)} = \frac{\partial L}{\partial C^{(t)}}$$

    為了便於推導,我們將損失函式$L(t)$分成兩塊,一塊是時刻$t$位置的損失$l(t)$,另一塊是時刻$t$之後損失$L(t+1)$,即:$$L(t) = \begin{cases} l(t) + L(t+1) & \text{if} \, t < \tau \\ l(t) & \text{if} \, t = \tau\end{cases}$$    

    而在最後的序列索引位置$\tau$的$\delta_h^{(\tau)}$和 $\delta_C^{(\tau)} $為:$$\delta_h^{(\tau)} =\frac{\partial L^{(\tau)}}{\partial O^{(\tau)}} \frac{\partial O^{(\tau)}}{\partial h^{(\tau)}} = V^T(\hat{y}^{(\tau)} - y^{(\tau)})$$$$\delta_C^{(\tau)} =\frac{\partial L^{(\tau)}}{\partial h^{(\tau)}} \frac{\partial h^{(\tau)}}{\partial C^{(\tau)}} = \delta_h^{(\tau)} \odot  o^{(\tau)} \odot (1 - tanh^2(C^{(\tau)}))$$

    接著我們由$\delta_C^{(t+1)},\delta_h^{(t+1)}$反向推導$\delta_h^{(t)}, \delta_C^{(t)}$。

    $\delta_h^{(t)}$的梯度由本層t時刻的輸出梯度誤差和大於t時刻的誤差兩部分決定,即:$$  \delta_h^{(t)} =\frac{\partial L}{\partial h^{(t)}}  =\frac{\partial l(t)}{\partial h^{(t)}} + \frac{\partial L(t+1)}{\partial h^{(t+1)}}  \frac{\partial h^{(t+1)}}{\partial h^{(t)}} = V^T(\hat{y}^{(t)} - y^{(t)}) + \delta_h^{(t+1)} \frac{\partial h^{(t+1)}}{\partial h^{(t)}}$$

    整個LSTM反向傳播的難點就在於$ \frac{\partial h^{(t+1)}}{\partial h^{(t)}}$這部分的計算。仔細觀察,由於$h^{(t)} = o^{(t)} \odot tanh(C^{(t)})$, 在第一項$o^{(t)}$中,包含一個$h$的遞推關係,第二項$tanh(C^{(t)})$就複雜了,$tanh$函式裡面又可以表示成:$$C^{(t)} = C^{(t-1)} \odot f^{(t)} + i^{(t)} \odot a^{(t)}$$

    $tanh$函式的第一項中,$f^{(t)} $包含一個$h$的遞推關係,在$tanh$函式的第二項中,$i^{(t)}$和$a^{(t)}$都包含$h$的遞推關係,因此,最終$ \frac{\partial h^{(t+1)}}{\partial h^{(t)}}$這部分的計算結果由四部分組成。即:

$$\Delta C = o^{(t+1)} \odot [1-tanh^2(C^{(t+1)})]$$

$$\frac{\partial h^{(t+1)}}{\partial h^{(t)}} = W_o^T [o^{(t+1)} \odot (1-o^{(t+1)}) \odot tanh(C^{(t+1)})] +  W_f^T [\Delta C  \odot f^{(t+1)} \odot (1-f^{(t+1)}) \odot C^{(t)}] + W_a^T \{ \Delta C  \odot i^{(t+1)} \odot [1-(a^{(t+1)})^2] \}  + W_i^T [\Delta C  \odot a^{(t+1)} \odot  i^{(t+1)}  \odot (1-i^{(t+1)})]$$

    而$\delta_C^{(t)}$的反向梯度誤差由前一層$\delta_C^{(t+1)}$的梯度誤差和本層的從$h^{(t)}$傳回來的梯度誤差兩部分組成,即:$$\delta_C^{(t)} =\frac{\partial L}{\partial C^{(t+1)}} \frac{\partial  C^{(t+1)}}{\partial C^{(t)}} + \frac{\partial L}{\partial h^{(t)}}\frac{\partial h^{(t)}}{\partial C^{(t)}} = \delta_C^{(t+1)}\odot f^{(t+1)} + \delta_h^{(t)} \odot  o^{(t)} \odot (1 - tanh^2(C^{(t)}))$$

    有了$\delta_h^{(t)}$和$\delta_C^{(t)}$, 計算這一大堆引數的梯度就很容易了,這裡只給出$W_f$的梯度計算過程,其他的$U_f, b_f, W_a, U_a, b_a, W_i, U_i, b_i, W_o, U_o, b_o,V, c$的梯度大家只要照搬就可以了。$$\frac{\partial L}{\partial W_f} =  \sum\limits_{t=1}^{\tau}\frac{\partial L}{\partial C^{(t)}} \frac{\partial C^{(t)}}{\partial f^{(t)}} \frac{\partial f^{(t)}}{\partial W_f} =\sum\limits_{t=1}^{\tau} [\delta_C^{(t)} \odot C^{(t-1)} \odot f^{(t)}\odot(1-f^{(t)})] (h^{(t-1)})^T$$

5. LSTM小結

    LSTM雖然結構複雜,但是隻要理順了裡面的各個部分和之間的關係,進而理解前向反向傳播演算法是不難的。當然實際應用中LSTM的難點不在前向反向傳播演算法,這些有演算法庫幫你搞定,模型結構和一大堆引數的調參才是讓人頭痛的問題。不過,理解LSTM模型結構仍然是高效使用的前提。

(歡迎轉載,轉載請註明出處。歡迎溝通交流: liujianping-ok@163.com)

參考資料:

1) Neural Networks and Deep Learning by By Michael Nielsen

2) Deep Learning, book by Ian Goodfellow, Yoshua Bengio, and Aaron Courville

3) UFLDL Tutorial

4)Understanding-LSTMs

\Delta C = o^{(t+1)} \odot [1-tanh^2(C^{(t+1)})]

相關文章