詳細闡述基於時間的反向傳播演算法(Back-Propagation Through Time,BPTT)

hearthougan發表於2018-09-20

上一節我們說了詳細展示RNN的網路結構以及前向傳播,在瞭解RNN的結構之後,如何訓練RNN就是一個重要問題,訓練模型就是更新模型的引數,也就是如何進行反向傳播,也就意味著如何對引數進行求導。本篇內容就是詳細介紹RNN的反向傳播演算法,即BPTT。


首先讓我們來用動圖來表示RNN的損失是如何產生的,以及如何進行反向傳播,如下圖所示。

上面兩幅圖片,已經很詳細的展示了損失是如何產生的, 以及如何來對引數求導,這是忽略細節的RNN反向傳播流程,我相信已經描述的非常清晰了。下圖(來自trask)描述RNN詳細結構中反向傳播的過程。

有了清晰的反向傳播的過程,我們接下來就需要進行理論的推到,由於符號較多,為了不至於混淆,根據下圖,現標記符號如表格所示:

公式符號表
符號 含義

K

輸入向量的大小(one-hot長度,也是詞典大小)

T

輸入的每一個序列的長度

H

隱藏層神經元的個數

X=\left \{ x_{1},x_{2},x_{3}....,x_{T} \right \}

樣本集合

x_{t}\epsilon \mathbb{R}^{K\times 1}

t時刻的輸入

y_{t}\epsilon \mathbb{R}^{K\times 1}

t時刻經過Softmax層的輸出。

\hat{y}_{t}\epsilon \mathbb{R}^{K\times 1}

t時刻輸入樣本的真實標籤

L_{t}

t時刻的損失函式,使用交叉熵函式,

L_t=-\hat{y}_t^Tlog(y_t)

L

序列對應的損失函式:

L=\sum\limits_t^T L_t

RNN的反向傳播是每處理完一個樣本就需要對引數進行更新,因此當執行完一個序列之後,總的損失函式就是各個時刻所得的損失之和。

s_{t}\epsilon \mathbb{R}^{H\times 1}

t個時刻RNN隱藏層的輸入。

h_{t}\epsilon \mathbb{R}^{H\times 1}

第t個時刻RNN隱藏層的輸出。

z_{t}\epsilon \mathbb{R}^{H\times 1}

輸出層的輸入,即Softmax函式的輸入

W\epsilon \mathbb{R}^{H\times K}

輸入層與隱藏層之間的權重。

U\epsilon \mathbb{R}^{H\times H}

上一個時刻的隱藏層 與 當前時刻隱藏層之間的權值。

V\epsilon \mathbb{R}^{K\times H}

隱藏層與輸出層之間的權重。

                                                                     \begin{matrix} \: \: \: \: \: \: \: \: \; \; \; \; \; \; \; \; \; \; \; \; \; s_t=Uh_{t-1}+Wx_t+b\\ \\ h_t=\sigma(s_t)\\ \\ \; \; \; \; z_t=Vh_t+c\\ \\ \; \; \; \; \; \; \; \; \; \; y_t=\mathrm{softmax}(z_t) \end{matrix}

我們對引數V,c求導比較方便,只有每一時刻的輸出對應的損失與V,c相關,可以直接進行求導,即:

                                                  \frac{\partial L}{\partial V} =\sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial V} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial z_{t}} \frac{\partial z_{t}}{\partial V} = \sum\limits_{t=1}^{\tau}(\hat{ y}_{t}-y_{t}) (h_{t})^T

                                                 \frac{\partial L}{\partial c} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial c} = \sum\limits_{t=1}^{T}\frac{\partial L_{t}}{\partial z_{t}} \frac{\partial z_{t}}{\partial c} = \sum\limits_{t=1}^{T}y_{t} - \hat{y}_{t}

要對引數W,U,b進行更新,就不那麼容易了,因為引數W,U,b雖是共享的,但是他們不只是對第t刻的輸出做出了貢獻,同樣對t+1時刻隱藏層的輸入s_{t+1}做出了貢獻,因此在對W,U,b引數求導的時候,需要從後向前一步一步求導。

假設我們在對t時刻的引數W,U,b求導,我們利用鏈式法則可得出:

                                                                   \frac{\partial L}{\partial W}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial W}

                                                                    \frac{\partial L}{\partial U}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial U}

                                                                    \frac{\partial L}{\partial b}=\frac{\partial L}{\partial h_{t}}\frac{\partial h_{t}}{\partial s_{t}}\frac{\partial s_{t}}{\partial b}

我們發現對W,U,b進行求導的時候,都需要先求出\frac{\partial L}{\partial h_{t}},因此我們設:

                                                                         \delta ^{t}=\frac{\partial L}{\partial h_{t}}=\frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}

那麼我們現在需要先求出\delta ^{t},則:

                                                                        \frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}=V^{T}(y_{t}-\hat{y}_{t})

                                                                   \begin{matrix} \frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}=U^{T}\delta ^{t+1}\odot \sigma ^{'}(z_{t+1})\\ \\ =U^{T}diag(\delta ^{t+1}) \sigma ^{'}(z_{t+1})\\ \\ =U^{T}diag(\sigma ^{'}(z_{t+1}))\delta ^{t+1} \\ \\ =U^{T}diag(1-h_{t+1}^{2})\delta ^{t+1} \end{matrix}

注:在求解啟用函式導數時,是將已知的部分求導之後,然後將它和啟用函式導數部分進行哈達馬乘積。啟用函式的導數一般是和前面的進行哈達馬乘積,這裡的啟用函式是雙曲正切,用矩陣中對角線元素表示向量中各個值的導數,可以去掉哈達馬乘積,轉化為矩陣乘法。

則:

                                                                \begin{matrix} \delta ^{t}=\frac{\partial L}{\partial h_{t}}=\frac{\partial L}{\partial z_{t}}\frac{\partial z_{t}}{\partial h_{t}}+\frac{\partial L}{\partial h_{t+1}}\frac{\partial h_{t+1}}{\partial h_{t}}\\ \\\: \; \; \; \; \; \; \; \; \; \; \; \; \; \; =V^{T}(y_{t}-\hat{y}_{t})+U^{T}\delta ^{t+1}\odot \sigma ^{'}(z_{t+1})\\ \\ \: \: \: \: \: \: \: \: \: \: \: \: \: \: \: \; =V^{T}(y_{t}-\hat{y}_{t})+U^{T}\delta ^{t+1} (1-h_{t+1}^{2}) \end{matrix}

我們求得\delta ^{t},之後,便可以回到最初對引數的求導,因此有:

                                                           \frac{\partial L}{\partial W} = \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial W} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(x_{t})^T

                                                            \frac{\partial L}{\partial b}= \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial b} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}

                                                           \frac{\partial L}{\partial U} = \sum\limits_{t=1}^{T}\frac{\partial L}{\partial h_{t}} \frac{\partial h_{t}}{\partial U} = \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(h_{t-1})^T

有了各個引數導數之後,我們可以進行引數更新:

                                                          W^{'}=W-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(x_{t})^T

                                                           U^{'}=U-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}(h_{t-1})^T

                                                          V^{'}=V-\theta \sum\limits_{t=1}^{T}(\hat{ y}_{t}-y_{t}) (h_{t})^T

                                                           b^{'}=b-\theta \sum\limits_{t=1}^{T}diag(1-(h_{t})^2)\delta^{t}

                                                          c^{'}=c- \theta \sum\limits_{t=1}^{T}y_{t} - \hat{y}_{t}


參考:

劉建平《迴圈神經網路(RNN)模型與前向反向傳播演算法

李弘毅老師《深度學習》

相關文章