原理分析:
BiLSTM(雙向長短期記憶網路) 是一種迴圈神經網路(RNN)的變體,它在自然語言處理任務中非常有效,其中包括給定一個長句子預測下一個單詞。
這種效果的主要原因包括以下幾點:
長短期記憶網路(LSTM)結構:LSTM 是一種特殊的 RNN,專門設計用於解決長序列依賴問題。相比於普通的 RNN,LSTM 有能力更好地捕捉長距離的依賴關係,因此適用於處理長句子。
雙向性:BiLSTM 透過在輸入序列的兩個方向上進行處理,即前向和後向,使得模型能夠同時捕捉到當前位置前後的上下文資訊。這樣,模型就能夠更全面地理解整個句子的語境,從而更準確地預測下一個單詞。
上下文資訊:BiLSTM 能夠透過記憶單元和門控機制(如輸入門、遺忘門、輸出門)來記憶並使用之前的輸入資訊。這使得模型能夠在預測下一個單詞時考慮到句子中前面的所有單詞,而不僅僅是最近的幾個單詞。
引數共享:由於 LSTM 的引數在整個序列上是共享的,模型能夠利用整個序列的資訊來進行預測,而不是僅僅依賴於當前時刻的輸入。
端到端學習:BiLSTM 可以透過端到端的方式進行訓練,這意味著模型可以直接從原始資料中學習輸入和輸出之間的對映關係,無需手工設計特徵或規則。
總的來說,BiLSTM 結合了雙向處理、長序列依賴建模和上下文資訊的利用,使得它能夠在給定一個長句子的情況下有效地預測下一個單詞。
程式碼實現:
BiLSTM的程式碼相對而言比較難找,很多提供的也不準確。筆者找了幾個執行成功的案例,針對案例中的BiLSTM演算法部分進行分析。
案例一:給定一個長句子預測下一個單詞
原文連結點選此給定一個長句子預測下一個單詞
class BiLSTM(nn.Module):
def __init__(self):
super(BiLSTM, self).__init__()
self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, bidirectional=True)
# fc
self.fc = nn.Linear(n_hidden * 2, n_class)
def forward(self, X):
# X: [batch_size, max_len, n_class]
batch_size = X.shape[0]
input = X.transpose(0, 1) # input : [max_len, batch_size, n_class]
hidden_state = torch.randn(1*2, batch_size, n_hidden) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
cell_state = torch.randn(1*2, batch_size, n_hidden) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))
outputs = outputs[-1] # [batch_size, n_hidden * 2]
model = self.fc(outputs) # model : [batch_size, n_class]
return model
model = BiLSTM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
針對def __init__(self)
:
這段程式碼定義了一個名為 BiLSTM
的模型類,它繼承自 nn.Module
類。在 __init__
方法中,首先呼叫 super(BiLSTM, self).__init__()
來初始化父類 nn.Module
,然後建立了一個雙向 LSTM 模型 self.lstm
。
input_size
引數指定了輸入的特徵維度,這裡設定為 n_class
,即輸入資料的特徵數量。hidden_size
引數指定了 LSTM 隱藏狀態的維度,這裡設定為 n_hidden
,即隱藏層的大小。bidirectional=True
表示這是一個雙向 LSTM,即包含前向和後向兩個方向的資訊。接著建立了一個全連線層 self.fc
,其中輸入特徵數量為 n_hidden * 2
,表示雙向 LSTM 輸出的隱藏狀態的維度乘以 2,輸出特徵數量為 n_class
,表示分類的類別數量。
針對def forward(self, X)
:
這段程式碼定義了模型的前向傳播方法forward
。在該方法中,首先接受輸入X
,其維度為[batch_size,max_len, n_class]
,其中batch_size
表示輸入資料的批次大小,max_len
表示序列的最大長度,n_class
表示輸入資料的特徵數量。接著透過transpose
方法將輸入X
的維度重新排列,以適應LSTM模型的輸入要求,即將序列的維度放在第二維上,結果儲存在input
中。
然後,建立了LSTM模型所需的初始隱藏狀態hidden_state
和細胞狀態cell_state
。這裡使用了隨機初始化的狀態,其維度為[num_layers * num_directions, batch_size, n_hidden]
,其中num_layers
表示 LSTM 的層數,預設為 1,num_directions
表示 LSTM 的方向數,預設為 2(雙向)。這裡的 1*2 表示單層雙向 LSTM。
接著,將輸入資料input
和初始狀態傳遞給 LSTM 模型self.lstm
,得到輸出outputs
。最後,取 LSTM 模型輸出的最後一個時間步的隱藏狀態作為模型輸出,即outputs[-1]
,其維度為[batch_size, n_hidden * 2]
,然後透過全連線層self.fc
進行分類,得到模型的輸出model
,其維度為[batch_size, n_class]
,即表示每個類別的得分。
針對
model = BiLSTM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
:
這段程式碼例項化了一個 BiLSTM 模型,並定義了損失函式 CrossEntropyLoss
和最佳化器 Adam
。損失函式用於計算模型輸出與目標標籤之間的誤差,最佳化器用於更新模型的引數,其中學習率 lr
設定為 0.001。
思考:
這個案例相對簡單,便於理解BiLSTM的程式碼設計,如果想要改寫為LSTM,則針對def forward(self,X)
中的num_directions
數值,由2->1,因為在單向 LSTM 中不需要考慮前向和後向兩個方向的隱藏狀態,其他部分保持不變。
class LSTM(nn.Module):
def __init__(self):
super(LSTM, self).__init__()
self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden)
# fc
self.fc = nn.Linear(n_hidden, n_class)
def forward(self, X):
# X: [batch_size, max_len, n_class]
batch_size = X.shape[0]
input = X.transpose(0, 1) # input : [max_len, batch_size, n_class]
hidden_state = torch.randn(1, batch_size, n_hidden) # [num_layers(=1), batch_size, n_hidden]
cell_state = torch.randn(1, batch_size, n_hidden) # [num_layers(=1), batch_size, n_hidden]
outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))
outputs = outputs[-1] # [batch_size, n_hidden]
model = self.fc(outputs) # model : [batch_size, n_class]
return model
model = LSTM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
回顧一下LSTM的知識:
1. 三個門:輸入門、輸出門、遺忘門,透過這三個門來決定哪些資訊需要被記憶,哪些需要被忘記。
2. cell state:中文翻譯為細胞態。cell state變數儲存的是當前時刻t及其前面所有時刻的混合資訊,也就是說,在LSTM中,資訊的記憶與維護都是透過cell state變數的。
3. hidden state:LSTM中的hidden_state其實就是cell state的一種過濾之後的資訊,更關注當前時間點的輸出結果。LSTM的hidden state其實就是當前時刻的output。
4. 輸入xt:x_tx當前時間點的輸入。不過,需要注意的是,在LSTM每一個時間步中,最終輸入其實由xt與上一時刻隱狀態ht−1組成。