pytorch lstm原始碼解讀

唐僧愛吃唐僧肉發表於2021-01-02

最近閱讀了pytorch中lstm的原始碼,發現其中有很多值得學習的地方。
首先檢視pytorch當中相應的定義

        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

lstm原理圖
對應公式:
圈1: f t = σ ( W i f x t + b i f + W h f h t − 1 + b h f ) f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) ft=σ(Wifxt+bif+Whfht1+bhf)
圈2: i t = σ ( W i i x t + b i i + W h i h t − 1 + b h i ) i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) it=σ(Wiixt+bii+Whiht1+bhi)
圈3: g t = tanh ⁡ ( W i g x t + b i g + W h g h t − 1 + b h g ) g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) gt=tanh(Wigxt+big+Whght1+bhg)
圈4: o t = σ ( W i o x t + b i o + W h o h t − 1 + b h o ) o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) ot=σ(Wioxt+bio+Whoht1+bho)
圈5: c t = f t ⊙ c t − 1 + i t ⊙ g t c_t = f_t \odot c_{t-1} + i_t \odot g_t ct=ftct1+itgt
圈6: h t = o t ⊙ tanh ⁡ ( c t ) h_t = o_t \odot \tanh(c_t) ht=ottanh(ct)
呼叫lstm的相應程式碼如下:

import torch
import torch.nn as nn
bilstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
input = torch.randn(5, 3, 10)
h0 = torch.randn(4, 3, 20)
c0 = torch.randn(4, 3, 20)
#with  open('D://test//input1.txt','w')  as  f:
#    f.write(str(input))
#with  open('D://test//h0.txt','w')  as  f:
#    f.write(str(h0))
#with  open('D://test//c0.txt','w')  as  f:
#    f.write(str(c0))
output, (hn, cn) = bilstm(input, (h0, c0))
print('output shape: ', output.shape)
print('hn shape: ', hn.shape)
print('cn shape: ', cn.shape)

這裡的input = (seq_len, batch, input_size),h_0 (num_layers * num_directions, batch, hidden_size),c_0 (num_layers * num_directions, batch, hidden_size)
觀察初始化部分的原始碼
初始化部分的原始碼可以看出這裡當為lstm層的時候,gate_size = 4*hidden_size

這裡當bidirectional = True時num_directions = 2,當bidirectional = False時num_directions = 1。
接下來的初始化部分self._flat_weigts_names中的數值,因為這裡總共定義了兩層,所以’weight_ih_l0’ = [80,10],‘weight_hh_l0’ = [80,20],‘bias_ih_l0’ = [80],‘bias_hh_l0’ = [80],‘weight_ih_l0_reverse’ = [80,10],‘weight_hh_l0_reverse’ = [80,20],‘bias_ih_l0_reverse’ = [80],‘bias_hh_l0_reverse’ = [80]
‘weight_ih_l1’ = [80,40],‘weight_hh_l1’ = [80,20],‘bias_ih_l1’ = [80],‘bias_hh_l1’ = [80]
‘weight_ih_l1_reverse’ = [80,40],‘weight_hh_l1_reverse’ = [80,20],‘bias_ih_l1_reverse’ = [80],‘bias_hh_l1_reverse’ = [80]
關於這些陣列的意義回讀一下之前的註釋內容
之前的註釋內容這裡面的weight_ih_l[k] = [80,10],其中的80是由4hidden_size = 420得到的,這4個引數分別為W_ii,W_if,W_ig,W_io,而weight_ih_l[k]是由這四個引數拼接得來的[80,10],同理可得到對應的weight_ih_l[k],weight_hh_l[k],bias_ih_l[k],bias_hh_l[k]的相應的含義。
其中,input = [5,3,10],h0 = [4,3,20],c0 = [4,3,20]
對應的lstm結構圖如下所示
對應的lstm結構圖h0中的[4,3,20]中的h0[0],h0[1],h0[2],h0[3]分別對應著h[0],h[1],h[2],h[3],每一個的shape都等於[3,20]
同理c0的原理一致。
對於公式進行分析
對於第一層的內容:
公式1: f t = σ ( W i f [ 20 , 10 ] x t + b i f [ 20 ] + W h f [ 20 , 20 ] h t − 1 + b h f [ 20 ] ) f_t = \sigma(W_{if}[20,10] x_t + b_{if}[20] + W_{hf}[20,20] h_{t-1} + b_{hf}[20]) ft=σ(Wif[20,10]xt+bif[20]+Whf[20,20]ht1+bhf[20])
公式2: i t = σ ( W i i [ 20 , 10 ] x t + b i i [ 20 ] + W h i [ 20 , 20 ] h t − 1 + b h i [ 20 ] ) i_t = \sigma(W_{ii}[20,10] x_t + b_{ii}[20] + W_{hi}[20,20] h_{t-1} + b_{hi}[20]) it=σ(Wii[20,10]xt+bii[20]+Whi[20,20]ht1+bhi[20])
公式3: g t = tanh ⁡ ( W i g [ 20 , 10 ] x t + b i g [ 20 ] + W h g [ 20 , 20 ] h t − 1 + b h g [ 20 ] ) g_t = \tanh(W_{ig}[20,10] x_t + b_{ig}[20] + W_{hg}[20,20] h_{t-1} + b_{hg}[20]) gt=tanh(Wig[20,10]xt+big[20]+Whg[20,20]ht1+bhg[20])
公式4: o t = σ ( W i o [ 20 , 10 ] x t + b i o [ 20 ] + W h o [ 20 , 20 ] h t − 1 + b h o [ 20 ] ) o_t = \sigma(W_{io}[20,10] x_t + b_{io}[20] + W_{ho}[20,20] h_{t-1} + b_{ho}[20]) ot=σ(Wio[20,10]xt+bio[20]+Who[20,20]ht1+bho[20])
公式5: c t = f t [ 20 , 20 ] ⊙ c t − 1 [ 20 , 20 ] + i t [ 20 , 20 ] ⊙ g t [ 20 , 20 ] c_t = f_t[20,20] \odot c_{t-1}[20,20] + i_t[20,20] \odot g_t[20,20] ct=ft[20,20]ct1[20,20]+it[20,20]gt[20,20]
公式6: h t = o t [ 20 , 20 ] ⊙ tanh ⁡ ( c t ) [ 20 , 20 ] h_t = o_t[20,20] \odot \tanh(c_t)[20,20] ht=ot[20,20]tanh(ct)[20,20]
對於第二層的內容:
公式1: f t = σ ( W i f [ 20 , 40 ] x t + b i f [ 20 ] + W h f [ 20 , 20 ] h t − 1 + b h f [ 20 ] ) f_t = \sigma(W_{if}[20,40] x_t + b_{if}[20] + W_{hf}[20,20] h_{t-1} + b_{hf}[20]) ft=σ(Wif[20,40]xt+bif[20]+Whf[20,20]ht1+bhf[20])
公式2: i t = σ ( W i i [ 20 , 40 ] x t + b i i [ 20 ] + W h i [ 20 , 20 ] h t − 1 + b h i [ 20 ] ) i_t = \sigma(W_{ii}[20,40] x_t + b_{ii}[20] + W_{hi}[20,20] h_{t-1} + b_{hi}[20]) it=σ(Wii[20,40]xt+bii[20]+Whi[20,20]ht1+bhi[20])
公式3: g t = tanh ⁡ ( W i g [ 20 , 40 ] x t + b i g [ 20 ] + W h g [ 20 , 20 ] h t − 1 + b h g [ 20 ] ) g_t = \tanh(W_{ig}[20,40] x_t + b_{ig}[20] + W_{hg}[20,20] h_{t-1} + b_{hg}[20]) gt=tanh(Wig[20,40]xt+big[20]+Whg[20,20]ht1+bhg[20])
公式4: o t = σ ( W i o [ 20 , 40 ] x t + b i o [ 20 ] + W h o [ 20 , 20 ] h t − 1 + b h o [ 20 ] ) o_t = \sigma(W_{io}[20,40] x_t + b_{io}[20] + W_{ho}[20,20] h_{t-1} + b_{ho}[20]) ot=σ(Wio[20,40]xt+bio[20]+Who[20,20]ht1+bho[20])
公式5: c t = f t [ 20 , 20 ] ⊙ c t − 1 [ 20 , 20 ] + i t [ 20 , 20 ] ⊙ g t [ 20 , 20 ] c_t = f_t[20,20] \odot c_{t-1}[20,20] + i_t[20,20] \odot g_t[20,20] ct=ft[20,20]ct1[20,20]+it[20,20]gt[20,20]
公式6: h t = o t [ 20 , 20 ] ⊙ tanh ⁡ ( c t ) [ 20 , 20 ] h_t = o_t[20,20] \odot \tanh(c_t)[20,20] ht=ot[20,20]tanh(ct)[20,20]

相關文章