BiLSTM-Attention Text Classification

Posted by ckxllf on 2020-04-22

  Overview

  This post walks through BiLSTM-Attention in full, using a simple text classification demo built with PyTorch.

  Text Classification in Practice

  Overall Setup

  First, we import the packages we need (the neural-network modules, the optimizer, autograd, and so on) and set the tensor data type:

  import numpy as np
  import torch
  import torch.nn as nn
  import torch.optim as optim
  from torch.autograd import Variable
  import torch.nn.functional as F
  import matplotlib.pyplot as plt

  dtype = torch.FloatTensor

  Next, we set a few basic hyperparameters and build a small dataset for binary sentiment classification. The dataset contains six sentences, half positive and half negative; in labels, 1 marks positive sentiment and 0 marks negative sentiment.

  embedding_dim = 3
  n_hidden = 5
  num_classes = 2  # 0 or 1

  sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
  labels = [1, 1, 1, 0, 0, 0]

  Next, we build the vocabulary: collect every word that appears in the dataset and assign each one an index:

  word_list = " ".join(sentences).split()
  word_list = list(set(word_list))
  word_dict = {w: i for i, w in enumerate(word_list)}
  vocab_size = len(word_dict)

  Then we define the inputs and targets. Each input is simply the sequence of vocabulary ids of the words in a sentence; we wrap both inputs and targets in Variable so that gradients can be computed:

  inputs = []
  for sen in sentences:
      inputs.append(np.asarray([word_dict[n] for n in sen.split()]))

  targets = []
  for out in labels:
      targets.append(out)

  input_batch = Variable(torch.LongTensor(inputs))
  target_batch = Variable(torch.LongTensor(targets))

  Next, we build the model:

  class BiLSTM_Attention(nn.Module):
      def __init__(self):
          super(BiLSTM_Attention, self).__init__()
          self.embedding = nn.Embedding(vocab_size, embedding_dim)
          self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
          self.out = nn.Linear(n_hidden * 2, num_classes)

      def attention_net(self, lstm_output, final_state):
          hidden = final_state.view(-1, n_hidden * 2, 1)
          attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)
          soft_attn_weights = F.softmax(attn_weights, 1)
          context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
          return context, soft_attn_weights.data.numpy()

      def forward(self, X):
          input = self.embedding(X)
          input = input.permute(1, 0, 2)
          hidden_state = Variable(torch.zeros(1*2, len(X), n_hidden))
          cell_state = Variable(torch.zeros(1*2, len(X), n_hidden))
          output, (final_hidden_state, final_cell_state) = self.lstm(input, (hidden_state, cell_state))
          output = output.permute(1, 0, 2)
          attn_output, attention = self.attention_net(output, final_hidden_state)
          return self.out(attn_output), attention

  The embedding layer takes the vocabulary size and the embedding dimension. It is followed by a bidirectional LSTM layer and a linear layer that maps the bidirectional hidden representation, of size n_hidden * 2, to the two output classes.
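  For reference, instantiating the class and printing it shows these three sub-modules; with the toy dataset above (16 unique words), the output looks roughly like this:

  print(BiLSTM_Attention())
  # BiLSTM_Attention(
  #   (embedding): Embedding(16, 3)
  #   (lstm): LSTM(3, 5, bidirectional=True)
  #   (out): Linear(in_features=10, out_features=2, bias=True)
  # )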

  A closer look at the attention step: first, the final hidden state is reshaped so that hidden has shape [batch_size, n_hidden * num_directions(=2), 1(=n_layer)]. The attention scores are obtained by batch-multiplying the LSTM output with hidden and squeezing the last dimension: [batch_size, n_step, n_hidden * num_directions(=2)] x [batch_size, n_hidden * num_directions(=2), 1] = [batch_size, n_step, 1], so attn_weights has shape [batch_size, n_step]. These scores are passed through a softmax over the step dimension, and the resulting weights are multiplied with the transposed LSTM output to get the context vector: [batch_size, n_hidden * num_directions(=2), n_step] x [batch_size, n_step, 1] = [batch_size, n_hidden * num_directions(=2), 1], which after squeezing leaves context with shape [batch_size, n_hidden * num_directions(=2)].
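  These shapes can be traced with dummy tensors, reusing the hyperparameters above (batch_size = 6, n_step = 3); this is only a sketch for checking dimensions, and the variable names here are just for illustration:

  # Shape walk-through of the attention step: batch_size=6, n_step=3,
  # n_hidden=5, num_directions=2.
  lstm_out = torch.randn(6, 3, n_hidden * 2)    # [batch_size, n_step, n_hidden * 2]
  final = torch.randn(2, 6, n_hidden)           # [num_directions, batch_size, n_hidden]
  hid = final.view(-1, n_hidden * 2, 1)         # [batch_size, n_hidden * 2, 1]
  scores = torch.bmm(lstm_out, hid).squeeze(2)  # [batch_size, n_step]
  weights = F.softmax(scores, 1)
  ctx = torch.bmm(lstm_out.transpose(1, 2), weights.unsqueeze(2)).squeeze(2)
  print(scores.shape, ctx.shape)                # torch.Size([6, 3]) torch.Size([6, 10])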

  Finally, the forward method runs the input through these layers in turn and returns both the class scores and the attention weights.

  Next, we instantiate the model and set up the loss function and optimizer:

  model = BiLSTM_Attention()
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=0.001)
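  Before training, a single untrained forward pass is a quick way to confirm the shapes end to end; a minimal check using the input_batch defined earlier:

  # One forward pass: class scores for 6 sentences x 2 classes,
  # attention weights for 6 sentences x 3 words.
  logits, attn = model(input_batch)
  print(logits.shape)  # torch.Size([6, 2])
  print(attn.shape)    # (6, 3) -- a NumPy array, since attention_net calls .numpy()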

  Finally, we train and test:

  # Training
  for epoch in range(5000):
      optimizer.zero_grad()
      output, attention = model(input_batch)
      loss = criterion(output, target_batch)
      if (epoch + 1) % 1000 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
      loss.backward()
      optimizer.step()

  # Test
  test_text = 'sorry hate you'
  tests = [np.asarray([word_dict[n] for n in test_text.split()])]
  test_batch = Variable(torch.LongTensor(tests))

  # Predict
  predict, _ = model(test_batch)
  predict = predict.data.max(1, keepdim=True)[1]
  if predict[0][0] == 0:
      print(test_text, "is Bad Mean...")
  else:
      print(test_text, "is Good Mean!!")

  Finally, we plot the attention weights to see the result:

  fig = plt.figure(figsize=(6, 3))  # [batch_size, n_step]
  ax = fig.add_subplot(1, 1, 1)
  ax.matshow(attention, cmap='viridis')
  ax.set_xticklabels([''] + ['first_word', 'second_word', 'third_word'], fontdict={'fontsize': 14}, rotation=90)
  ax.set_yticklabels([''] + ['batch_1', 'batch_2', 'batch_3', 'batch_4', 'batch_5', 'batch_6'], fontdict={'fontsize': 14})
  plt.show()

  Debugging

  First, look at the data after it has been read in, and at the inputs and targets after the text has been converted to ids.
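  One way to inspect these values is simply to print them; a minimal sketch using the variables defined above:

  print(word_dict)     # the 16-word vocabulary built from the six sentences
  print(vocab_size)    # 16
  print(input_batch)   # word ids, shape [6, 3]
  print(target_batch)  # tensor([1, 1, 1, 0, 0, 0])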

  After running the full training loop, we can see the result: the test sentence is classified as negative ("sorry hate you is Bad Mean...").

  


  Finally, we obtain the attention matrix, visualized by the plotting code above.


From "ITPUB blog", link: http://blog.itpub.net/69945560/viewspace-2687671/. Please credit the source when reposting.
