RNN的PyTorch實現

大雄的叮噹貓發表於2022-11-19

官方實現

PyTorch已經實現了一個RNN類,就在torch.nn工具包中,透過torch.nn.RNN呼叫。

使用步驟:

  1. 例項化類;
  2. 將輸入層向量和隱藏層向量初始狀態值傳給例項化後的物件,獲得RNN的輸出。

在例項化該類時,需要傳入如下屬性:

  • input_size:輸入層神經元個數;
  • hidden_size:每層隱藏層的神經元個數;
  • num_layers:隱藏層層數,預設設定為1層;
  • nonlinearity:啟用函式的選擇,可選是'tanh'或者'relu',預設設定為'tanh';
  • bias:偏置係數,可選是'True'或者'False',預設設定為'True';
  • batch_first:可選是'True'或者'False',預設設定為'False';
  • dropout:預設設定為0。若為非0,將在除最後一層的每層RNN輸出上引入Dropout層,dropout機率就是該非零值;
  • bidirectional:預設設定為False。若為True,即為雙向RNN。

RNN的輸入有兩個,一個是input,一個是h0。input就是輸入層向量,h0就是隱藏層初始狀態值。
若沒有采用批次輸入,則輸入層向量的形狀為(L, Hin);
若採用批次輸入,且batch_first為False,則輸入層向量的形狀為(L, N, Hin);
若採用批次輸入,且batch_first為True,則輸入層向量的形狀為(N, L, Hin);
對於(N, L, Hin),在文字輸入時,可以按順序理解為(每次輸入幾句話,每句話有幾個字,每個字由多少維的向量表示)。

若沒有采用批次輸入,則隱藏層向量的形狀為(D * num_layers, Hout);
若採用批次輸入,則隱藏層向量的形狀為(D * num_layers, N, Hout);
注意,batch_first的設定對隱藏層向量的形狀不起作用。

RNN的輸出有兩個,一個是output,一個是hn。output包含了每個時間步最後一層的隱藏層狀態,hn包含了最後一個時間步每層的隱藏層狀態。
若沒有采用批次輸入,則輸出層向量的形狀為(L, D * Hout);
若採用批次輸入,且batch_first為False,則輸出層向量的形狀為(L, N, D * Hout);
若採用批次輸入,且batch_first為True,則輸出層向量的形狀為(N, L, D * Hout)。

引數解釋:

  • N代表的是批次大小;
  • L代表的是輸入的序列長度;
  • 若是雙向RNN,則D的值為2;若是單向RNN,則D的值為1;
  • Hin在數值上是輸入層神經元個數;
  • Hout在數值上是隱藏層神經元個數。
import torch
import torch.nn as nn
rnn = nn.RNN(10, 20, 1, batch_first=True)  # 例項化一個單向單層RNN
input = torch.randn(5, 3, 10)
h0 = torch.randn(1, 5, 20)
output, hn = rnn(input, h0)

手寫復現

復現程式碼

import torch
import torch.nn as nn

class MyRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = torch.randn(self.hidden_size, self.input_size) * 0.01
        self.weight_hh = torch.randn(self.hidden_size, self.hidden_size) * 0.01
        self.bias_ih = torch.randn(self.hidden_size)
        self.bias_hh = torch.randn(self.hidden_size)
        
    def forward(self, input, h0):
        N, L, input_size = input.shape
        output = torch.zeros(N, L, self.hidden_size)
        for t in range(L):
            x = input[:, t, :].unsqueeze(2)  # 獲得當前時刻的輸入特徵,[N, input_size, 1]。unsqueeze(n),在第n維上增加一維
            w_ih_batch = self.weight_ih.unsqueeze(0).tile(N, 1, 1)  # [N, hidden_size, input_size]
            w_hh_batch = self.weight_hh.unsqueeze(0).tile(N, 1, 1)  # [N, hidden_size, hidden_size]
            w_times_x = torch.bmm(w_ih_batch, x).squeeze(-1)  # [N, hidden_size]。squeeze(n),在第n維上減小一維
            w_times_h = torch.bmm(w_hh_batch, h0.unsqueeze(2)).squeeze(-1)  # [N, hidden_size]
            h0 = torch.tanh(w_times_x + self.bias_ih  + w_times_h + self.bias_hh)
            output[:, t, :] = h0
        return output, h0.unsqueeze(0)

驗證正確性

my_rnn = MyRNN(10, 20)
input = torch.randn(5, 3, 10)
h0 = torch.randn(5, 20)
my_output, my_hn = my_rnn(input, h0)
print(output.shape == my_output.shape, hn.shape == my_hn.shape)
True True

主要參考

官方說明文件

相關文章