從乘法求導法則到BPTT演算法

Hisi發表於2022-04-08

本文為手稿,旨在搞清楚為什麼BPTT演算法會多路反向求導,而不是一個感性的認識。

 

假設我們要對E3求導(上圖中的L3),那麼則有:

所以S2是W的函式,也就是說,我們不能說:

 因為WS2 = WS2(w),S2裡面包含了W這個變數,S2是W的函式,也許有人會說:“S2裡面的W是常數吧”,那麼請想一想S2的一般表示式。(這裡我其實還是有點過不去,但是我覺得應該是這樣的,不知道各位是否有理解方法)

所以有:

 而對函式WS2(w)求導(對W求導),結果為:

 S02和W2在RNN中的位置為:

 

 

 再次注意,上面兩個值不是變數,是一個具體的值。

 

然後再求(WS1)`:

另外關於W1,這裡我不太清楚是否繼續要用W2,因為畢竟是對第t=3時刻的W求導,如果後面知道了,再改也不遲。

 

 繼續求下去:

 我們假設S-1是全0的向量,那麼S0`就會是0.

 

然後,我們把上面分開求的結果合併起來,直接計算S3對W的導數:

 

 

 最後一行就是最終的結果,其實這三項分別對應:

 下面是數學表示: 

所以,

BPTT反向求導為什麼必然會有多路,實際上是因為 S2是W的函式,所以要運用乘法求導法則,最後完全求出(S2W)`之後,便可以寫成這樣的形式:

 

 以下是完整草稿:

 

 

 

 

 

 

 本文截圖部分來自我的NLP課程喬波老師的PPT。

相關文章