Reading the PyTorch LSTM source code
I recently read through the LSTM source code in PyTorch and found a lot worth learning. First, look at the corresponding definition in PyTorch:
\begin{array}{ll} \\
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
h_t = o_t \odot \tanh(c_t) \\
\end{array}
These correspond to the following formulas (numbered 1 to 6 in the diagram):

Circle 1: $f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf})$
Circle 2: $i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi})$
Circle 3: $g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg})$
Circle 4: $o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho})$
Circle 5: $c_t = f_t \odot c_{t-1} + i_t \odot g_t$
Circle 6: $h_t = o_t \odot \tanh(c_t)$
The code to invoke the LSTM is as follows:
import torch
import torch.nn as nn

bilstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
input = torch.randn(5, 3, 10)
h0 = torch.randn(4, 3, 20)
c0 = torch.randn(4, 3, 20)
output, (hn, cn) = bilstm(input, (h0, c0))
print('output shape: ', output.shape)
print('hn shape: ', hn.shape)
print('cn shape: ', cn.shape)
Here input has shape (seq_len, batch, input_size), h_0 has shape (num_layers * num_directions, batch, hidden_size), and c_0 has shape (num_layers * num_directions, batch, hidden_size).
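The output shape is worth noting as well: since the layer is bidirectional, the last dimension of output is num_directions * hidden_size = 40. A minimal check of the shapes, reusing the same configuration:

```python
import torch
import torch.nn as nn

bilstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
input = torch.randn(5, 3, 10)   # (seq_len, batch, input_size)
h0 = torch.randn(4, 3, 20)      # (num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(4, 3, 20)      # (num_layers * num_directions, batch, hidden_size)
output, (hn, cn) = bilstm(input, (h0, c0))

# output: (seq_len, batch, num_directions * hidden_size)
print(tuple(output.shape))  # (5, 3, 40)
print(tuple(hn.shape))      # (4, 3, 20)
print(tuple(cn.shape))      # (4, 3, 20)
```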
Looking at the initialization part of the source code, we can see that for an LSTM layer, gate_size = 4 * hidden_size. When bidirectional = True, num_directions = 2; when bidirectional = False, num_directions = 1.
Consider the values in self._flat_weights_names. Since two layers are defined here, the parameter shapes are:
weight_ih_l0 = [80, 10], weight_hh_l0 = [80, 20], bias_ih_l0 = [80], bias_hh_l0 = [80], weight_ih_l0_reverse = [80, 10], weight_hh_l0_reverse = [80, 20], bias_ih_l0_reverse = [80], bias_hh_l0_reverse = [80];
weight_ih_l1 = [80, 40], weight_hh_l1 = [80, 20], bias_ih_l1 = [80], bias_hh_l1 = [80];
weight_ih_l1_reverse = [80, 40], weight_hh_l1_reverse = [80, 20], bias_ih_l1_reverse = [80], bias_hh_l1_reverse = [80].
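These parameter shapes can be confirmed directly from the module itself; a small sketch iterating over named_parameters with the same configuration:

```python
import torch.nn as nn

bilstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
shapes = {name: tuple(p.shape) for name, p in bilstm.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)
# e.g. weight_ih_l0 is (80, 10) but weight_ih_l1 is (80, 40):
# layer 1 consumes the 2 * 20 = 40-dimensional concatenated
# bidirectional output of layer 0
```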
For the meaning of these arrays, refer back to the earlier comments. In weight_ih_l[k] = [80, 10], the 80 comes from 4 * hidden_size = 4 * 20. The four blocks are W_ii, W_if, W_ig, and W_io, and weight_ih_l[k] is the [80, 10] matrix obtained by concatenating these four parameters. The meanings of weight_hh_l[k], bias_ih_l[k], and bias_hh_l[k] follow analogously.
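This concatenation can be made explicit with chunk: PyTorch stores the four blocks of weight_ih_l[k] stacked along dim 0 in the order W_ii | W_if | W_ig | W_io, so splitting along that dimension recovers the per-gate matrices. A minimal sketch:

```python
import torch.nn as nn

bilstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
# weight_ih_l0: [4 * hidden_size, input_size] = [80, 10]
W_ii, W_if, W_ig, W_io = bilstm.weight_ih_l0.chunk(4, dim=0)
print(tuple(W_ii.shape))  # (20, 10) -- one gate's input-to-hidden matrix
```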
Here input = [5, 3, 10], h0 = [4, 3, 20], c0 = [4, 3, 20]. The corresponding LSTM structure diagram is shown below.
In h0 = [4, 3, 20], the slices h0[0], h0[1], h0[2], h0[3] correspond to h[0], h[1], h[2], h[3] respectively, each with shape [3, 20]. The same applies to c0.
Now analyze the formulas with shapes filled in.
For the first layer:
Formula 1: $f_t = \sigma(W_{if}[20,10]\, x_t + b_{if}[20] + W_{hf}[20,20]\, h_{t-1} + b_{hf}[20])$
Formula 2: $i_t = \sigma(W_{ii}[20,10]\, x_t + b_{ii}[20] + W_{hi}[20,20]\, h_{t-1} + b_{hi}[20])$
Formula 3: $g_t = \tanh(W_{ig}[20,10]\, x_t + b_{ig}[20] + W_{hg}[20,20]\, h_{t-1} + b_{hg}[20])$
Formula 4: $o_t = \sigma(W_{io}[20,10]\, x_t + b_{io}[20] + W_{ho}[20,20]\, h_{t-1} + b_{ho}[20])$
Formula 5: $c_t = f_t[3,20] \odot c_{t-1}[3,20] + i_t[3,20] \odot g_t[3,20]$
Formula 6: $h_t = o_t[3,20] \odot \tanh(c_t)[3,20]$
(With batch size 3, each gate activation has shape [3, 20].)
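The six equations for a single layer and direction can be checked against nn.LSTMCell, which implements exactly one such step. The sketch below recomputes the gates by hand from the cell's own weights (stacked in the W_ii | W_if | W_ig | W_io order) and compares the result:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=10, hidden_size=20)
x = torch.randn(3, 10)   # one time step, batch of 3
h = torch.randn(3, 20)
c = torch.randn(3, 20)

# all four gates at once: weight_ih is the stacked (W_ii | W_if | W_ig | W_io)
gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
i, f, g, o = gates.chunk(4, dim=1)
i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
g = torch.tanh(g)
c_next = f * c + i * g            # formula 5
h_next = o * torch.tanh(c_next)   # formula 6

h_ref, c_ref = cell(x, (h, c))
print(torch.allclose(h_next, h_ref, atol=1e-6))  # True
print(torch.allclose(c_next, c_ref, atol=1e-6))  # True
```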
For the second layer:
Formula 1: $f_t = \sigma(W_{if}[20,40]\, x_t + b_{if}[20] + W_{hf}[20,20]\, h_{t-1} + b_{hf}[20])$
Formula 2: $i_t = \sigma(W_{ii}[20,40]\, x_t + b_{ii}[20] + W_{hi}[20,20]\, h_{t-1} + b_{hi}[20])$
Formula 3: $g_t = \tanh(W_{ig}[20,40]\, x_t + b_{ig}[20] + W_{hg}[20,20]\, h_{t-1} + b_{hg}[20])$
Formula 4: $o_t = \sigma(W_{io}[20,40]\, x_t + b_{io}[20] + W_{ho}[20,20]\, h_{t-1} + b_{ho}[20])$
Formula 5: $c_t = f_t[3,20] \odot c_{t-1}[3,20] + i_t[3,20] \odot g_t[3,20]$
Formula 6: $h_t = o_t[3,20] \odot \tanh(c_t)[3,20]$
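For the second layer, x_t is the 40-dimensional concatenation of the forward and backward hidden states from layer 1, which is why weight_ih_l1 is [80, 40]. The relation between output and hn can also be checked empirically; a sketch, relying on hn being ordered [layer0 fwd, layer0 bwd, layer1 fwd, layer1 bwd] for this 2-layer bidirectional model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bilstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
input = torch.randn(5, 3, 10)
output, (hn, cn) = bilstm(input)  # zero initial states

# output's last dim splits into forward (first 20) and backward (last 20) halves
fwd, bwd = output[..., :20], output[..., 20:]
# forward half at the last time step equals the last layer's forward hn
print(torch.allclose(fwd[-1], hn[2], atol=1e-6))  # True
# backward half at the first time step equals the last layer's backward hn
print(torch.allclose(bwd[0], hn[3], atol=1e-6))   # True
```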