《神經網路的梯度推導與程式碼驗證》之LSTM的前向傳播和反向梯度推導

SumwaiLiu發表於2020-09-07

前言

在本篇章,我們將專門針對LSTM這種網路結構進行前向傳播介紹和反向梯度推導。

關於LSTM的梯度推導,這一塊確實挺不好掌握,原因有:

  • 一些經典的deep learning 教程,例如花書缺乏相關的內容
  • 一些經典的論文不太好看懂,例如On the difficulty of training Recurrent Neural Networks上有LSTM的梯度推導但看得我還是一頭霧水(可能是我能力有限。。)
  • 網上關於LSTM的梯度推導雖多,但缺乏保證其正確性的驗證實驗

考慮到上述問題,本篇章將以最低限度的知識依賴進行LSTM的反向梯度推導,所有推導基礎均基於《神經網路的梯度推導與程式碼驗證》之數學基礎篇:矩陣微分與求導。為保證所得無誤,後續將通過tensorflow的自動微分工具驗證LSTM梯度推導結論的準確性。

 

更多相關內容請見《神經網路的梯度推導與程式碼驗證》系列介紹

 


 

目錄

 

提醒:

  • 後續會反覆出現$\boldsymbol{\delta}^{l}$這個(類)符號,它的定義為$\boldsymbol{\delta}^{l} = \frac{\partial l}{\partial\boldsymbol{z}^{\boldsymbol{l}}}$,即loss $l$對$\boldsymbol{z}^{\boldsymbol{l}}$的導數
  • 其中$\boldsymbol{z}^{\boldsymbol{l}}$表示第$l$層(DNN,CNN,RNN或其他例如max pooling層等)未經過啟用函式的輸出。
  • $\boldsymbol{a}^{\boldsymbol{l}}$則表示$\boldsymbol{z}^{\boldsymbol{l}}$經過啟用函式後的輸出。

這些符號會貫穿整個系列,還請留意。


 

5.1 LSTM的前向傳播

在RNN模型裡,我們講到了RNN具有如下的結構,每個序列索引位置$t$都有一個隱藏狀態$\boldsymbol{h}^{(t)}$。

 

如果我們只關注RNN的核心迴圈部分而不看$\boldsymbol{o}^{(t)}$,$\boldsymbol{L}^{(t)}$和$\boldsymbol{y}^{(t)}$,則RNN的模型可以簡化成如下圖的形式:

 

圖中可以很清晰看出在隱藏狀態$\boldsymbol{h}^{(t)}$由$\boldsymbol{x}^{(t)}$和$\boldsymbol{h}^{(t-1)}$共同得到。得到的$\boldsymbol{h}^{(t)}$方面用於當前層的模型損失計算,另一方面用於計算下一層的$\boldsymbol{h}^{(t+1)}$。

 

由於RNN梯度消失的問題,大牛們對於序列索引位置t的隱藏結構做了改進,可以說通過一些技巧讓隱藏結構複雜了起來,來避免梯度消失的問題,這樣的特殊RNN就是我們的LSTM。由於LSTM有很多的變種,這裡我們以最常見的LSTM為例講述。LSTM的結構如下圖:

 

 

 5.1.1 LSTM之細胞狀態

上面我們給出了LSTM的模型結構,下面我們就一點點的剖析LSTM模型在每個序列索引位置$t$時刻的內部結構。

 

 從上圖中可以看出,在每個序列索引位置$t$時刻向前傳播的除了和RNN一樣的隱藏狀態$\boldsymbol{h}^{(t+1)}$,還多了另一個隱藏狀態,如圖中上面的長橫線。這個隱藏狀態我們一般稱為細胞狀態(Cell State),記為$\boldsymbol{C}^{(t)}$。如下圖所示:

 

我們可以看到從$\boldsymbol{C}^{(t - 1)}$到$\boldsymbol{C}^{(t)}$,似乎經過了若干乘法和加法操作。

 

除了細胞狀態,LSTM圖中還有了很多奇怪的結構,這些結構一般稱之為門控結構(Gate)。LSTM在在每個序列索引位置t的門一般包括遺忘門,輸入門和輸出門三種。下面我們就來研究上圖中LSTM的遺忘門,輸入門和輸出門以及細胞狀態。

 

5.1.2 LSTM之遺忘門

遺忘門(forget gate)顧名思義,是控制是否遺忘的,在LSTM中即以一定的概率控制是否遺忘上一層的隱藏細胞狀態。遺忘門子結構如下圖所示:

 

 

圖中輸入的有上一序列的隱藏狀態$\boldsymbol{h}^{(t - 1)}$和$t$時刻的輸入$\boldsymbol{x}^{(t - 1)}$,通過一個啟用函式(一般是sigmoid),得到遺忘門的輸出$\boldsymbol{f}^{(t)}$:

$\boldsymbol{f}^{(t)} = \sigma\left( {\boldsymbol{W}_{f}\boldsymbol{h}^{(t - 1)} + \boldsymbol{U}_{f}\boldsymbol{x}^{(t - 1)} + \boldsymbol{b}_{f}} \right)$

 

由於sigmoid的值域介於0~1之間,所以這裡的$\boldsymbol{f}^{(t)}$表示保留上一個時間步$\boldsymbol{h}^{(t - 1)}$的多大的成分。雖然“保留”跟“遺忘門”這兩個詞是概念上相反的,但大家似乎已經習慣用遺忘門來稱呼這個$\boldsymbol{f}^{(t)}$了。

 

5.1.3 LSTM之輸入門

輸入門(input gate)負責管理當前序列位置的輸入,它的子結構如下圖:

 

輸入門$\boldsymbol{i}^{(t)}$的數學表示式為:

$\boldsymbol{i}^{(t)} = \sigma\left( {\boldsymbol{W}_{i}\boldsymbol{h}^{(t - 1)} + \boldsymbol{U}_{i}\boldsymbol{x}^{(t - 1)} + \boldsymbol{b}_{i}} \right)$

 

對比遺忘門的表示式,除了矩陣的下標發生了點改變以外,其他都一樣。

而遺忘門的控制物件則是$\boldsymbol{h}^{(t - 1)}$和$\boldsymbol{x}^{(t - 1)}$組合的產物,它的表示式如下:

$\boldsymbol{a}^{(t)} = \sigma\left( {\boldsymbol{W}_{a}\boldsymbol{h}^{(t - 1)} + \boldsymbol{U}_{a}\boldsymbol{x}^{(t - 1)} + \boldsymbol{b}_{a}} \right)$

 

5.1.4 LSTM之細胞狀態更新

在研究LSTM輸出門之前,我們要先看看LSTM之細胞狀態。前面的遺忘門和輸入門的結果都會作用於細胞狀態$\boldsymbol{C}^{(t)}$。我們來看看$\boldsymbol{C}^{(t - 1)}$是如何得到$\boldsymbol{C}^{(t)}$的:

 

細胞狀態$\boldsymbol{C}^{(t)}$由兩部分組成,第一部分是$\boldsymbol{C}^{(t - 1)}$和遺忘門$\boldsymbol{f}^{(t)}$的Hadamard積(逐元素相乘),第二部分是$\boldsymbol{a}^{(t)}$和輸入門$\boldsymbol{i}^{(t)}$的Hadamard積:

$\boldsymbol{C}^{(t)} = \boldsymbol{C}^{(t)}\bigodot\boldsymbol{f}^{(t)} + \boldsymbol{a}^{(t)}\bigodot\boldsymbol{i}^{(t)}$

 

5.1.5 LSTM之輸出門

有了新的隱藏細胞狀態$\boldsymbol{C}^{(t)}$,現在來到輸出門:

 

輸出門$\boldsymbol{o}^{(t)}$的數學表示式為:

$\boldsymbol{o}^{(t)} = \sigma\left( {\boldsymbol{W}_{o}\boldsymbol{h}^{(t - 1)} + \boldsymbol{U}_{o}\boldsymbol{x}^{(t - 1)} + \boldsymbol{b}_{o}} \right)$

 

而輸出門所控制的物件,則是$tanh\left( \boldsymbol{C}^{(t)} \right)$,兩者共同形成$t$時間步下的隱藏狀態$\boldsymbol{h}^{(t)}$:

$\boldsymbol{h}^{(t)} = \boldsymbol{o}^{(t)}\bigodot tanh\left( \boldsymbol{C}^{(t)} \right)$

 

 

5.1.6 LSTM前向傳播總結

現在我們來總結下LSTM前向傳播演算法。LSTM模型有兩個隱藏狀態$\boldsymbol{h}^{(t)}$,$\boldsymbol{C}^{(t)}$,模型引數恰好是RNN的4倍整。

 

前向傳播過程在每個時間步$t$上發生的順序為:

1)更新遺忘門輸出:

$\boldsymbol{f}^{(t)} = \sigma\left( {\boldsymbol{W}_{f}\boldsymbol{h}^{(t - 1)} + \boldsymbol{U}_{f}\boldsymbol{x}^{(t)} + \boldsymbol{b}_{f}} \right)$

2)更新輸入門和其控制物件:

$\boldsymbol{i}^{(t)} = \sigma\left( {\boldsymbol{W}_{i}\boldsymbol{h}^{(t - 1)} + \boldsymbol{U}_{i}\boldsymbol{x}^{(t)} + \boldsymbol{b}_{i}} \right)$

$\boldsymbol{a}^{(t)} = tanh\left( {\boldsymbol{W}_{a}\boldsymbol{h}^{(t - 1)} + \boldsymbol{U}_{a}\boldsymbol{x}^{(t)} + \boldsymbol{b}_{a}} \right)$

3)更新細胞狀態,從而$\left. \boldsymbol{C}^{(t - 1)}\longrightarrow\boldsymbol{C}^{(t)} \right.$:

$\boldsymbol{C}^{(t)} = \boldsymbol{C}^{(t - 1)}\bigodot\boldsymbol{f}^{(t)} + \boldsymbol{a}^{(t)}\bigodot\boldsymbol{i}^{(t)}$

4)更新輸出門和其控制物件,從而$\left. \boldsymbol{h}^{(t - 1)}\longrightarrow\boldsymbol{h}^{(t)} \right.$:

$\boldsymbol{o}^{(t)} = \sigma\left( {\boldsymbol{W}_{o}\boldsymbol{h}^{(t - 1)} + \boldsymbol{U}_{o}\boldsymbol{x}^{(t - 1)} + \boldsymbol{b}_{o}} \right)$

$\boldsymbol{h}^{(t)} = \boldsymbol{o}^{(t)}\bigodot tanh\left( \boldsymbol{C}^{(t)} \right)$

5)得到當前時間步$t$的預測輸出:

${\hat{\boldsymbol{y}}}^{(t)} = \sigma\left( {\boldsymbol{V}\boldsymbol{h}^{(t)} + \boldsymbol{c}} \right)$ 

 


 

5.2 LSTM的反向梯度推導

在RNN中,為了計算反向傳播誤差,我們通過隱藏狀態$\boldsymbol{h}^{(t)}$的梯度$\boldsymbol{\delta}^{(t)}$一步一步向前傳播。在LSTM中也類似,只不過我們這裡由兩種隱藏狀態$\boldsymbol{h}^{(t)}$和$\boldsymbol{C}^{(t)}$,這裡我們定義兩種$\boldsymbol{\delta}$:

 $\boldsymbol{\delta}_{h}^{(t)} = \frac{\partial L}{\partial\boldsymbol{h}^{(t)}}$

$\boldsymbol{\delta}_{C}^{(t)} = \frac{\partial L}{\partial\boldsymbol{C}^{(t)}}$

 

為了方便找到梯度的遞推模式,下面是根據前向傳播公式給出資料在LSTM中資料的前向流動示意圖:

對於$t = T$,即時間序列截止的那個時間步,我們可以得到:

 $\boldsymbol{\delta}_{h}^{(T)} = \boldsymbol{V}^{T}\left( {{\hat{\boldsymbol{y}}}^{(T)} - \boldsymbol{y}^{(T)}} \right)$

$\boldsymbol{\delta}_{C}^{(T)} = \left( \frac{\partial\boldsymbol{h}^{(T)}}{\partial\boldsymbol{C}^{(T)}} \right)^{T}\frac{\partial L}{\partial\boldsymbol{h}^{(T)}} = \boldsymbol{\delta}_{h}^{(T)}\bigodot\boldsymbol{o}^{(T)}\bigodot{tanh}^{'}\left( \boldsymbol{C}^{(T)} \right)$

第一個式子的證明見vanilla RNN的前向傳播和反向梯度推導  的4.2節;第二個式子根據等式$\boldsymbol{h}^{(t)} = \boldsymbol{o}^{(t)}\bigodot tanh\left( \boldsymbol{C}^{(t)} \right)$結合數學基礎篇:矩陣微分與求導的理論即可秒證出來。

 

對於$t < T$時,我們要利用$\boldsymbol{\delta}_{h}^{(t + 1)}$和$\boldsymbol{\delta}_{C}^{(t + 1)}$遞推得到$\boldsymbol{\delta}_{h}^{(t)}$和$\boldsymbol{\delta}_{C}^{(t)}$。

 

先來推導$\boldsymbol{\delta}_{h}^{(t)}$的遞推公式:

根據上圖我們知道,$\boldsymbol{\delta}_{h}^{(t)}$的誤差來源如下:

1)$\left. l\left( t \right)\longrightarrow\boldsymbol{h}^{(t)} \right.$

2)$\left. \boldsymbol{h}^{(t + 1)}\longrightarrow\boldsymbol{o}^{(t + 1)}\longrightarrow\boldsymbol{h}^{(t)} \right.$

3)$\left. \boldsymbol{C}^{(t + 1)}\longrightarrow\boldsymbol{i}^{(t + 1)}\longrightarrow\boldsymbol{h}^{(t)} \right.$

4)$\left. \boldsymbol{C}^{(t + 1)}\longrightarrow\boldsymbol{a}^{(t + 1)}\longrightarrow\boldsymbol{h}^{(t)} \right.$

5)$\left. \boldsymbol{C}^{(t + 1)}\longrightarrow\boldsymbol{f}^{(t + 1)}\longrightarrow\boldsymbol{h}^{(t)} \right.$

 

根據鏈式法則和全微分方程,有:

$\boldsymbol{\delta}_{h}^{(t)} = \frac{\partial L\left( t \right)}{\partial\boldsymbol{h}^{(t)}} = \frac{\partial l\left( t \right)}{\partial\boldsymbol{h}^{(t)}} + \left( \frac{\partial\boldsymbol{C}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}} \right)^{T}\boldsymbol{\delta}_{C}^{(t + 1)} + \left( {\frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{o}^{(t)}}\frac{\partial\boldsymbol{o}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}}} \right)^{T}\boldsymbol{\delta}_{h}^{(t + 1)}$

 

注意:上式中特地用了$\frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{o}^{(t)}}\frac{\partial\boldsymbol{o}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}}$而不是$\frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}}$。因為在$\boldsymbol{h}^{(t + 1)}$與$\boldsymbol{h}^{(t)}$之間存在多條傳播路徑的情況下,$\frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{o}^{(t)}}\frac{\partial\boldsymbol{o}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}} \neq \frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}}$。我們用$\frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{o}^{(t)}}\frac{\partial\boldsymbol{o}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}}$規定了從$\boldsymbol{h}^{(t + 1)}$到$\boldsymbol{h}^{(t)}$的誤差傳播路徑必須是$\left. \boldsymbol{h}^{(t + 1)}\longrightarrow\boldsymbol{o}^{(t + 1)}\longrightarrow\boldsymbol{h}^{(t)} \right.$而不是其他的路徑。如果是用$\frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}}$這個符號,則是預設要考慮所有從$\boldsymbol{h}^{(t + 1)}$到$\boldsymbol{h}^{(t)}$的誤差傳播路徑。

 

上面這個遞推公式需要解決三個問題,$\frac{\partial l\left( t \right)}{\partial\boldsymbol{h}^{(t)}}$,$\left( \frac{\partial\boldsymbol{C}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}} \right)^{T}$和$\left( {\frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{o}^{(t)}}\frac{\partial\boldsymbol{o}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}}} \right)^{T}$的求解。

 

對於$\frac{\partial l\left( t \right)}{\partial\boldsymbol{h}^{(t)}}$,根據vanilla RNN的前向傳播和反向梯度推導  的4.2節,它滿足:

$\frac{\partial l\left( t \right)}{\partial\boldsymbol{h}^{(t)}} = \boldsymbol{V}^{T}\left( {{\hat{\boldsymbol{y}}}^{(t)} - \boldsymbol{y}^{(t)}} \right)$

 

我們接下來求$\frac{\partial\boldsymbol{C}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}}$:

注意:因為下面的公式實在太長了,所以為節省空間,我們用“~”表示這個位置原本的數學表示式與上一行相同位置的數學表示式一樣。

基於$\boldsymbol{C}^{(t)} = \boldsymbol{C}^{(t - 1)}\bigodot\boldsymbol{f}^{(t)} + \boldsymbol{a}^{(t)}\bigodot\boldsymbol{i}^{(t)}$逐層展開,我們得到:

$d\boldsymbol{C}^{(t + 1)} = \boldsymbol{C}^{(t)}\bigodot d\boldsymbol{f}^{({t + 1})} + \boldsymbol{i}^{({t + 1})}\bigodot d\boldsymbol{a}^{({t + 1})} + \boldsymbol{a}^{({t + 1})}\bigodot d\boldsymbol{i}^{({t + 1})}$

$= diag\left( \boldsymbol{C}^{(t)} \right)d\boldsymbol{f}^{({t + 1})} + diag\left( \boldsymbol{i}^{({t + 1})} \right)d\boldsymbol{a}^{({t + 1})} + diag\left( \boldsymbol{a}^{({t + 1})} \right)d\boldsymbol{i}^{({t + 1})}$

$= diag\left( \boldsymbol{C}^{(t)} \right)d\boldsymbol{f}^{({t + 1})} + diag\left( \boldsymbol{a}^{({t + 1})} \right)d\boldsymbol{i}^{({t + 1})} + diag\left( \boldsymbol{i}^{({t + 1})} \right)d\boldsymbol{a}^{({t + 1})}$

$\left. = diag\left( {\boldsymbol{C}^{(t)}\bigodot\boldsymbol{f}^{({t + 1})}\bigodot\left( {1 - \boldsymbol{f}^{({t + 1})}} \right)} \right)\boldsymbol{W}_{f}d\boldsymbol{h}^{(t)} + \right.\sim\left. + \right.\sim$

$\left. = \right.\sim\left. + diag\left( {\boldsymbol{a}^{({t + 1})}\bigodot\boldsymbol{i}^{({t + 1})}\bigodot\left( {1 - \boldsymbol{i}^{({t + 1})}} \right)} \right)\boldsymbol{W}_{i}d\boldsymbol{h}^{(t)} + \right.\sim$

 

因為${tanh}^{'}\left( x \right) = \left( {1 - {tanh\left( x \right)}^{2}} \right)$,所以:

$\left. d\boldsymbol{C}^{(t + 1)} = \right.\sim\left. + \right.\sim + diag\left( {\boldsymbol{i}^{({t + 1})}\bigodot\left( {1 - {\boldsymbol{a}^{({t + 1})}}^{2}} \right)} \right)\boldsymbol{W}_{a}d\boldsymbol{h}^{(t)}$

 

整理上式我們得到:

$\frac{\partial\boldsymbol{C}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}} = diag\left( {\boldsymbol{C}^{(t)}\bigodot\boldsymbol{f}^{({t + 1})}\bigodot\left( {1 - \boldsymbol{f}^{({t + 1})}} \right)} \right)\boldsymbol{W}_{f} + diag\left( {\boldsymbol{a}^{({t + 1})}\bigodot\boldsymbol{i}^{({t + 1})}\bigodot\left( {1 - \boldsymbol{i}^{({t + 1})}} \right)} \right)\boldsymbol{W}_{i} + diag\left( {\boldsymbol{i}^{({t + 1})}\bigodot\left( {1 - {\boldsymbol{a}^{({t + 1})}}^{2}} \right)} \right)\boldsymbol{W}_{a}$

 

接下來是$\frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{o}^{(t)}}\frac{\partial\boldsymbol{o}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}}$的推導過程:

$d\boldsymbol{h}^{({t + 1})} = tanh\left( \boldsymbol{C}^{({t + 1})} \right)\bigodot d\boldsymbol{o}^{({t + 1})} = diag\left( {tanh\left( \boldsymbol{C}^{({t + 1})} \right)} \right)diag\left( {\boldsymbol{o}^{({t + 1})}\bigodot\left( {1 - \boldsymbol{o}^{({t + 1})}} \right)} \right)d\left( {\boldsymbol{W}_{o}\boldsymbol{h}^{(t)}} \right) = diag\left( {tanh\left( \boldsymbol{C}^{({t + 1})} \right)\bigodot\boldsymbol{o}^{({t + 1})}\bigodot\left( {1 - \boldsymbol{o}^{({t + 1})}} \right)} \right)\boldsymbol{W}_{o}d\boldsymbol{h}^{(t)}$

 

所以$\frac{\partial\boldsymbol{h}^{(t + 1)}}{\partial\boldsymbol{o}^{(t)}}\frac{\partial\boldsymbol{o}^{(t + 1)}}{\partial\boldsymbol{h}^{(t)}} = diag\left( {tanh\left( \boldsymbol{C}^{({t + 1})} \right)\bigodot\boldsymbol{o}^{({t + 1})}\bigodot\left( {1 - \boldsymbol{o}^{({t + 1})}} \right)} \right)$

 

於是我們現在得到了從$\boldsymbol{\delta}_{C}^{(t + 1)}$和$\boldsymbol{\delta}_{h}^{(t + 1)}$推得$\boldsymbol{\delta}_{h}^{(t)}$的遞推公式。

 


 

接下來我們利用$\boldsymbol{\delta}_{h}^{(t)}$和$\boldsymbol{\delta}_{C}^{(t + 1)}$來推得$\boldsymbol{\delta}_{C}^{(t)}$:

根據LSTM的前向示意圖,我們有:

 $\boldsymbol{\delta}_{C}^{(t)} = \left( \frac{\partial\boldsymbol{h}^{(t)}}{\partial\boldsymbol{c}^{(t)}} \right)^{T}\boldsymbol{\delta}_{h}^{(t)} + {\left( \frac{\partial\boldsymbol{c}^{(t + 1)}}{\partial\boldsymbol{c}^{(t)}} \right)^{T}\boldsymbol{\delta}}_{C}^{(t + 1)}$

 

容易求得$\frac{\partial\boldsymbol{h}^{(t)}}{\partial\boldsymbol{c}^{(t)}} = \left( \frac{\partial\boldsymbol{h}^{(t)}}{\partial\boldsymbol{C}^{(t)}} \right)^{T}\frac{\partial L\left( t \right)}{\partial\boldsymbol{h}^{(t)}} = \boldsymbol{o}^{(t)} \odot \left( {1 - {tanh}^{2}\left( \boldsymbol{C}^{(t)} \right)} \right)^{2}$

 

同樣也容易求得$\frac{\partial\boldsymbol{c}^{(t + 1)}}{\partial\boldsymbol{c}^{(t)}} = diag\left( \boldsymbol{f}^{({t + 1})} \right)$

 

所以得到:

$\boldsymbol{\delta}_{C}^{(t)} = \left( \frac{\partial\boldsymbol{h}^{(t)}}{\partial\boldsymbol{c}^{(t)}} \right)^{T}\boldsymbol{\delta}_{h}^{(t)} + {\left( \frac{\partial\boldsymbol{c}^{(t + 1)}}{\partial\boldsymbol{c}^{(t)}} \right)^{T}\boldsymbol{\delta}}_{C}^{(t + 1)} = \boldsymbol{o}^{(t)} \odot \left( {1 - {tanh}^{2}\left( \boldsymbol{C}^{(t)} \right)} \right)^{2}{\odot \boldsymbol{\delta}}_{h}^{(t)} + \boldsymbol{f}^{({t + 1})} \odot \boldsymbol{\delta}_{C}^{(t + 1)}$

 

現在,我們能計算$\boldsymbol{\delta}_{h}^{(t)}$和$\boldsymbol{\delta}_{C}^{(t)}$了,有了它們,計算變數的梯度就比較容易了,這裡只以計算$\boldsymbol{W}_{f}$的梯度計算為例:

我們令${\boldsymbol{z}^{(t)} = \boldsymbol{W}}_{f}\boldsymbol{h}^{(t - 1)} + \boldsymbol{U}_{f}\boldsymbol{x}^{(t)} + \boldsymbol{b}_{f}$,則:

$\frac{\partial L}{\partial\boldsymbol{W}_{f}} = {\sum\limits_{t = 1}^{T}\left( \frac{\partial\boldsymbol{C}_{t}}{\partial\boldsymbol{z}^{(t)}} \right)^{T}}\frac{\partial L}{\partial\boldsymbol{C}_{t}}\left( \boldsymbol{h}^{(t - 1)} \right)^{T}$

$d\boldsymbol{C}^{(t)} = \boldsymbol{C}^{(t - 1)} \odot d\boldsymbol{f}^{(t)} = diag\left( \boldsymbol{C}^{({t - 1})} \right)\left( {\left( {\boldsymbol{f}^{(t)} \odot \left( {1 - \boldsymbol{f}^{(t)}} \right)} \right) \odot d\boldsymbol{z}^{(t)}} \right) = diag\left( \boldsymbol{C}^{({t - 1})} \right)\left( {diag\left( {\boldsymbol{f}^{(t)} \odot \left( {1 - \boldsymbol{f}^{(t)}} \right)} \right)d\boldsymbol{z}^{(t)}} \right) = diag\left( {\boldsymbol{f}^{(t)} \odot \left( {1 - \boldsymbol{f}^{(t)}} \right) \odot \boldsymbol{C}^{({t - 1})}} \right)d\boldsymbol{z}^{(t)}$

 

所以$\frac{\partial\boldsymbol{C}_{t}}{\partial\boldsymbol{z}^{(t)}} = diag\left( {\boldsymbol{f}^{(t)} \odot \left( {1 - \boldsymbol{f}^{(t)}} \right) \odot \boldsymbol{C}^{({t - 1})}} \right)$

所以得到:

 $\frac{\partial L}{\partial\boldsymbol{W}_{f}} = {\sum\limits_{t = 1}^{T}\left\lbrack {\boldsymbol{\delta}_{C}^{(t)} \odot \boldsymbol{C}^{(t - 1)} \odot \boldsymbol{f}^{(t)} \odot \left\lbrack {1 - \boldsymbol{f}^{(t)}} \right\rbrack} \right\rbrack}\left( \boldsymbol{h}^{(t - 1)} \right)^{T}$

 

其他變數的梯度按照上述類似的方式可依次求得,在這裡不做過多敘述。


 

5.3 LSTM 能改善梯度消失的原因

首先需要明確的是,RNN 中的梯度消失/梯度爆炸和普通的 MLP 或者深層 CNN 中梯度消失/梯度爆炸的含義不一樣。MLP/CNN 中不同的層有不同的引數,各是各的梯度;而 RNN 中同樣的權重在各個時間步共享,最終的梯度$~g$= 各個時間步的梯度$g^{(t)}$之和。

 

因此,RNN 中總的梯度是不會消失的。即便梯度越傳越弱,那也只是遠距離的梯度消失,由於近距離的梯度不會消失,所有梯度之和便不會消失。RNN 所謂梯度消失的真正含義是,梯度被近距離梯度主導,導致模型難以學到遠距離的依賴關係。

 

LSTM 中梯度的傳播有很多條路徑,但$\boldsymbol{C}^{(t)} = \boldsymbol{C}^{(t - 1)}\bigodot\boldsymbol{f}^{(t)} + \boldsymbol{a}^{(t)}\bigodot\boldsymbol{i}^{(t)}$這條路徑上只有逐元素相乘和相加的操作,梯度流最穩定;但是其他路徑上梯度流與普通 RNN 類似,照樣會發生相同的權重矩陣反覆連乘。

 

由於總的遠距離梯度 = 各條路徑的遠距離梯度之和,即便其他遠距離路徑梯度消失了,只要保證有一條遠距離路徑(就是上面說的那條高速公路)梯度不消失,總的遠距離梯度就不會消失(正常梯度 + 消失梯度 = 正常梯度)。因此 LSTM 通過改善一條路徑上的梯度問題拯救了總體的遠距離梯度。

 

如果本文對您有所幫助的話,不妨點下“推薦”讓它能幫到更多的人,謝謝。


 

參考資料

(歡迎轉載,轉載請註明出處。歡迎留言或溝通交流: lxwalyw@gmail.com)

相關文章