關於 RNN 迴圈神經網路的反向傳播求導

極客鋒行發表於2021-01-11

關於 RNN 迴圈神經網路的反向傳播求導

本文是對 RNN 迴圈神經網路中的每一個神經元進行反向傳播求導的數學推導過程,下面還使用 PyTorch 對導數公式進行程式設計求證。

RNN 神經網路架構

一個普通的 RNN 神經網路如下圖所示:

圖片1

其中 \(x^{\langle t \rangle}\) 表示某一個輸入資料在 \(t\) 時刻的輸入;\(a^{\langle t \rangle}\) 表示神經網路在 \(t\) 時刻時的hidden state,也就是要傳送到 \(t+1\) 時刻的值;\(y^{\langle t \rangle}\) 則表示在第 \(t\) 時刻輸入資料傳入以後產生的預測值,在進行預測或 sampling\(y^{\langle t \rangle}\) 通常作為下一時刻即 \(t+1\) 時刻的輸入,也就是說 \(x^{\langle t \rangle}=\hat{y}^{\langle t \rangle}\) ;下面對資料的維度進行說明。

  • 輸入: \(x\in\mathbb{R}^{n_x\times m\times T_x}\) 其中 \(n_x\) 表示每一個時刻輸入向量的長度;\(m\) 表示資料批量數(batch);\(T_x\) 表示共有多少個輸入的時刻(time step)。
  • hidden state:\(a\in\mathbb{R}^{n_a\times m\times T_x}\) 其中 \(n_a\) 表示每一個 hidden state 的長度。
  • 預測:\(y\in\mathbb{R}^{n_y\times m\times T_y}\) 其中 \(n_y\) 表示預測輸出的長度;\(T_y\) 表示共有多少個輸出的時刻(time step)。

RNN 神經元

下圖所示的是一個特定的 RNN 神經元:

圖片2

上圖說明了在第 \(t\) 時刻的神經元中,資料的輸入 \(x^{\langle t \rangle}\) 和上一層的 hidden state \(a^{\langle t \rangle}\) 是如何經過計算得到下一層的 hidden state 和預測輸出 \(\hat{y}^{\langle t \rangle}\)

下面是對五個引數的維度說明:

  • \(W_{aa}\in\mathbb{R}^{n_a\times n_a}\)
  • \(W_{ax}\in\mathbb{R}^{n_a\times n_x}\)
  • \(b_a\in\mathbb{R}^{n_a\times 1}\)
  • \(W_{ya}\in\mathbb{R}^{n_y\times n_a}\)
  • \(b_y\in\mathbb{R}^{n_y\times 1}\)

計算 \(t\) 時刻的 hidden state \(a^{\langle t \rangle}\)

\[\begin{split} z1^{\langle t \rangle} &= W_{aa} a^{\langle t-1 \rangle} + W_{ax} x^{\langle t \rangle} + b_a\\ a^{\langle t \rangle} &= \tanh(z1^{\langle t \rangle}) \end{split} \]

預測 \(t\) 時刻的輸出 \(\hat{y}^{\langle t \rangle}\)

\[\begin{split} z2^{\langle t \rangle} &= W_{ya} a^{\langle t \rangle} + b_y\\ \hat{y}^{\langle t \rangle} &= softmax(z2^{\langle t \rangle}) = \frac{e^{z2^{\langle t \rangle}}}{\sum_{i=1}^{n_y}e^{z2_i^{\langle t \rangle}}} \end{split} \]

RNN 迴圈神經網路反向傳播

在當今流行的深度學習程式設計框架中,我們只需要編寫一個神經網路的結構和負責神經網路的前向傳播,至於反向傳播的求導和引數更新,完全由框架搞定;即便如此,我們在學習階段也要自己動手證明一下反向傳播的有效性。

RNN 神經元的反向傳播

下圖是 RNN 神經網路中的一個基本的神經元,圖中標註了反向傳播所需傳來的引數和輸出等。

圖片3

就如一個全連線的神經網路一樣,損失函式 \(J\) 的導數通過微積分的鏈式法則(chain rule)反向傳播到每一個時間軸上。

為了方便,我們將損失函式關於神經元中引數的偏導符號簡記為 \(\mathrm{d}\mathit{parameters}\) ;例如將 \(\frac{\partial J}{\partial W_{ax}}\) 記為 \(\mathrm{d}W_{ax}\)

圖片4

上圖的反向傳播的實現並沒有包括全連線層和 Softmax 層。

反向傳播求導

計算損失函式關於各個引數的偏導數之前,我們先引入一個計算圖(computation graph),其演示了一個 RNN 神經元的前向傳播和如何利用計算圖進行鏈式法則的反向求導。

image

因為當進行反向傳播求導時,我們需要將整個時間軸的輸入全部輸入之後,才可以從最後一個時刻開始往前傳進行反向傳播,所以我們假設 \(t\) 時刻就為最後一個時刻 \(T_x\)

如果我們想要先計算 \(\frac{\partial\ell}{\partial W_{ax}}\) 所以我們可以從計算圖中看到,反向傳播的路徑:

image

我們需要按部就班的分別對從 \(W_{ax}\) 計算到 \(\ell\) 一路相關的變數進行求偏導,利用鏈式法則,將紅色路線上一路的偏導數相乘到一起,就可以求出偏導數 \(\frac{\partial\ell}{\partial W_{ax}}\) ;所以我們得到:

\[\begin{split} \frac{\partial\ell}{\partial W_{ax}} &= \frac{\partial\ell}{\partial\ell^{\langle t\rangle}} {\color{Red}{ \frac{\partial\ell^{\langle t\rangle}}{\partial\hat{y}^{\langle t\rangle}} \frac{\partial\hat{y}^{\langle t\rangle}}{\partial z2^{\langle t\rangle}} }} \frac{\partial z2^{\langle t\rangle}}{\partial a^{\langle t\rangle}} \frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}} \frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}} \end{split} \]

在上面的公式中,我們僅需要分別求出每一個偏導即可,其中紅色的部分就是關於 \(\mathrm{Softmax}\) 的求導,關於 \(\mathrm{Softmax}\) 求導的推導過程,可以看本人的另一篇部落格: 關於 Softmax 迴歸的反向傳播求導數過程

關於 \(\mathrm{tanh}\) 的求導公式如下:

\[\frac{\partial \tanh(x)} {\partial x} = 1 - \tanh^2(x) \]

所以上面的式子就得到:

\[\begin{split} \frac{\partial\ell}{\partial W_{ax}} &= \frac{\partial\ell}{\partial\ell^{\langle t\rangle}} {\color{Red}{ \frac{\partial\ell^{\langle t\rangle}}{\partial\hat{y}^{\langle t\rangle}} \frac{\partial\hat{y}^{\langle t\rangle}}{\partial z2^{\langle t\rangle}} }} \frac{\partial z2^{\langle t\rangle}}{\partial a^{\langle t\rangle}} \frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}} \frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}}\\ &= {\color{Red}{ (\hat{y}^{\langle t\rangle}-y^{\langle t\rangle}) }} W_{ya} (1-\tanh^2(z1^{\langle t\rangle})) x^{\langle t\rangle} \end{split} \]

我們就可以得到在最後時刻 \(t\) 引數 \(W_{ax}\) 的偏導數。

關於上面式子中的偏導數的計算,除了標量對矩陣的求導,在後面還包括了兩個一個矩陣或向量對另一個矩陣或向量中的求導,實際上這是非常麻煩的一件事。

比如在計算 \(\frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}}\) 偏導數的時候,我們發現 \(z1^{\langle t\rangle}\) 是一個 \(\mathbb{R}^{n_a\times m}\) 的矩陣,而 \(W_{ax}\) 則是一個 \(\mathbb{R}^{n_a\times n_x}\) 的矩陣,這一項就是一個矩陣對另一個矩陣求偏導,如果直接對其求導我們將會得到一個四維的矩陣 \(\mathbb{R}^{n_a\times n_x\times n_a\times m}\)雅可比矩陣 Jacobian matrix);只不過這個高維矩陣中偏導數的值有很多 \(0\)

在神經網路中,如果直接將這個高維矩陣直接生搬硬套進梯度下降裡更新引數是不可行,因為我們需要得到的梯度是關於自變數同型的向量或矩陣而且我們還要處理更高維度的矩陣的乘法;所以我們需要將結果進行一定的處理得到我們僅僅需要的資訊。

一般在深度學習框架中都會有自動求梯度的功能包,這些包(比如 PyTorch )中就只允許一個標量對向量或矩陣求導,其他情況是不允許的,除非在反向傳播的函式裡傳入一個同型的權重向量或矩陣才可以得到導數。

我們先簡單求出一個偏導數 \(\frac{\partial\ell}{\partial W_{ax}}\) 我們下面使用 PyTorch 中的自動求梯度的包進行驗證我們的公式是否正確。

import torch
# 這是神經網路中的一些架構的引數
n_x = 6
n_y = 6
m = 1
T_x = 5
T_y = 5
n_a = 3
# 定義所有引數矩陣
# requires_grad 為 True 表明在涉及這個變數的運算時建立計算圖
# 為了之後反向傳播求導
W_ax = torch.randn((n_a, n_x), requires_grad=True)
W_aa = torch.randn((n_a, n_a), requires_grad=True)
ba = torch.randn((n_a, 1), requires_grad=True)
W_ya = torch.randn((n_y, n_a), requires_grad=True)
by = torch.randn((n_y, 1), requires_grad=True)
# t 時刻的輸入和上一時刻的 hidden state
x_t = torch.randn((n_x, m), requires_grad=True)
a_prev = torch.randn((n_a, m), requires_grad=True)
y_t = torch.randn((n_y, m), requires_grad=True)
# 開始模擬一個神經元 t 時刻的前向傳播
# 從輸入一直到計算出 loss
z1_t = torch.matmul(W_ax, x_t) + torch.matmul(W_aa, a_prev) + ba
z1_t.retain_grad()
a_t = torch.tanh(z1_t)
a_t.retain_grad()
z2_t = torch.matmul(W_ya, a_t) + by
z2_t.retain_grad()
y_hat = torch.exp(z2_t) / torch.sum(torch.exp(z2_t), dim=0)
y_hat.retain_grad()
loss_t = -torch.sum(y_t * torch.log(y_hat), dim=0)
loss_t.retain_grad()
# 對最後的 loss 標量開始進行反向傳播求導
loss_t.backward()
# 我們就可以得到 W_ax 的導數
# 儲存在字尾 _autograd 變數中,表明是由框架自動求導得到的
W_ax_autograd = W_ax.grad
# 檢視框架計算得到的導數
W_ax_autograd
tensor([[ 0.5252,  1.1938, -0.2352,  1.1571, -1.0168,  0.3195],
        [-1.0536, -2.3949,  0.4718, -2.3213,  2.0398, -0.6410],
        [-0.0316, -0.0717,  0.0141, -0.0695,  0.0611, -0.0192]])
# 我們對自己推演出的公式進行手動計算導數
# 儲存在字尾 _manugrad 變數中,表明是手動由公式計算得到的
W_ax_manugrad = torch.matmul(torch.matmul((y_hat - y_t).T, W_ya).T * (1 - torch.square(torch.tanh(z1_t))), x_t.T)
#torch.matmul(torch.matmul(W_ya.T, y_hat - y_t) * (1 - torch.square(torch.tanh(z1_t))), x_t.T)
# 輸出手動計算的導數
W_ax_manugrad
tensor([[ 0.5195,  1.1809, -0.2327,  1.1447, -1.0058,  0.3161],
        [-1.0195, -2.3172,  0.4565, -2.2461,  1.9737, -0.6202],
        [-0.0309, -0.0703,  0.0138, -0.0681,  0.0599, -0.0188]],
       grad_fn=<MmBackward>)
# 檢視兩種求導結果的之差的 L2 範數
torch.norm(W_ax_manugrad - W_ax_autograd)
tensor(0.1356, grad_fn=<CopyBackwards>)

通過上面的程式設計輸出可以看到,我們手動計算的導數和框架自己求出的導數雖然有一定的誤差,但是一一對照可以大體看到我們手動求出來的導數大體是對的,並沒有說錯的非常離譜。

但上面只是當 \(t=T_x\)\(t\) 時刻是最後一個輸入單元的時候,也就是說所求的關於 \(_W{ax}\) 的導數只是全部導數的一部分,因為引數共享,所以每一時刻的神經元都有對 \(W_{ax}\) 的導數,所以需要將所有時刻的神經元關於 \(W_{ax}\) 的導數全部加起來。

\(t\) 不是最後一時刻,可能是神經網路裡的中間的某一時刻的神經元;也就是說,在進行反向傳播的時候,想要求 \(t\) 時刻的導數,就得等到 \(t+1\) 時刻的導數值傳進來,然後根據鏈式法則才可以計算當前時刻引數的導數。

下面是一個簡易的計算圖,只繪製出了 \(W_ax\)\(\ell\) 的計算中,共涉及到哪些變數(在整個神經網路中的 \(W_{ax}\) 的權重引數是共享的):

image

下面使用一個視訊展示整個神經網路中從 \(W_{ax}\) 到一個資料批量的損失值 \(\ell\) 的大體流向:

計算完 \(\ell\) 之後就可以計算 \(\frac{\partial\ell}{\partial W_{ax}}\) 的導數值,但是 RNN 神經網路的反向傳播區別於全連線神經網路的。

image

然後,我們演示一下如何進行反向傳播的,注意看每一個時刻的 \(a^{\langle t\rangle}\) 的計算都是等 \(a^{\langle t+1\rangle}\) 的導數值傳進來才進行計算的;同樣地,\(W_{ax}\) 導數的計算也不是一步到位的,也是需要等到所有時刻的 \(a\) 的值全部傳到才計算完。

所以對於神經網路中間某一個單元 \(t\) 我們有:

\[\begin{split} \frac{\partial\ell}{\partial W_{ax}} &= {\color{Red}{ \left( \frac{\partial\ell}{\partial a^{\langle t\rangle}} +\frac{\partial\ell}{\partial z1^{\langle t+1\rangle}} \frac{\partial z1^{\langle t+1\rangle}}{\partial a^{\langle t\rangle}} \right) }} \frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}} \frac{\partial z1^{\langle t\rangle}}{\partial W_{ax}} \end{split} \]

關於紅色的部分的意思是需要等到 \(t+1\) 時刻的導數值傳進來,然後才可以進行對 \(t+1\) 時刻關於當前時刻 \(t\) 的引數求導,最後得到引數梯度的一個分量。其實若仔細展開每一個偏導項,就像是一個遞迴一樣,每次求某一時刻的導數總是要從最後一時刻往前傳到當前時刻才可以進行。

多元複合函式的求導法則

如果函式 \(u=\varphi(t)\)\(v=\psi(t)\) 都在點 \(t\) 可導,函式 \(z=f(u,v)\) 在對應點 \((u,v)\) 具有連續偏導數,那麼複合函式 \(z=f[\varphi(t),\psi(t)]\) 在點 \(t\) 可導,且有

\[\frac{\mathrm{d}z}{\mathrm{d}t}=\frac{\partial z}{\partial u}\frac{\mathrm{d}u}{\mathrm{d}t}+\frac{\partial z}{\partial v}\frac{\mathrm{d}v}{\mathrm{d}t} \]

下面使用一張計算圖說明 \(a^{\langle t\rangle}\)\(\ell\) 的計算關係。

image

也就是說第 \(t\) 時刻 \(\ell\) 關於 \(a^{\langle t\rangle}\) 的導數是由兩部分相加組成,也就是說是由兩條路徑反向傳播,這兩條路徑分別是 \(\ell\to\ell^{\langle t\rangle}\to\hat{y}^{\langle t\rangle}\to z2^{\langle t\rangle}\to a^{\langle t\rangle}\)\(\ell\to\ell^{\langle t+1\rangle}\to\hat{y}^{\langle t+1\rangle}\to z2^{\langle t+1\rangle}\to a^{\langle t+1\rangle}\to z1^{\langle t+1\rangle}\to a^{\langle t\rangle}\) ,我們將這兩條路徑導數之和使用 \(\mathrm{d}a_{\mathrm{next}}\) 表示。

所以我們可以得到在中間某一時刻的神經單元關於 \(W_{ax}\) 的導數為:

\[\frac{\partial\ell}{\partial W_{ax}}=\left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t \rangle}\right)\right) x^{\langle t \rangle T} \]

通過同樣的方法,我們就可以得到其它引數的導數:

\[\begin{align} \frac{\partial\ell}{\partial W_{aa}} &= \left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t\rangle}) \right)\right) a^{\langle t-1 \rangle T}\\ \frac{\partial\ell}{\partial b_a} & = \sum_{batch}\left( da_{next} * \left( 1-\tanh^2(z1^{\langle t\rangle}) \right)\right)\\ \end{align} \]

除了傳遞引數的導數,在第 \(t\) 時刻還需要傳送 \(\ell\) 關於 \(z1^{\langle t\rangle}\) 的導數到 \(t-1\) 時刻,將需要傳送到上一時刻的導數記作為 \(\mathrm{d}a_{\mathrm{prev}}\) 我們得到:

\[\begin{split} \mathrm{d}a_{\mathrm{prev}} &= \mathrm{d}a_\mathrm{next}\frac{\partial a^{\langle t\rangle}}{\partial z1^{\langle t\rangle}}\frac{\partial z1^{\langle t\rangle}}{\partial a^{\langle t-1\rangle}}\\ &= { W_{aa}}^T\left(\mathrm{d}a_{\mathrm{next}} * \left( 1-\tanh^2(z1^{\langle t\rangle}) \right)\right) \end{split} \]

可以看到,一個迴圈神經網路的反向傳播實際上是非常複雜的,因為每一時刻的神經元都與引數有計算關係,所以反向傳播時的路徑非常雜亂,其中還涉及到了高維的矩陣,所以在計算時需要對高維矩陣進行一定的矩陣代數轉換才方便導數和更新引數的計算。

相關文章