多頭注意力機制是一種用於處理序列資料的神經網路結構,在自然語言處理領域中得到廣泛應用。它可以幫助模型更好地理解和學習輸入序列中的資訊,提高模型在各種任務上的效能。
多頭注意力機制是基於注意力機制的改進版本,它引入了多個注意力頭,每個頭都可以關注輸入序列中不同位置的資訊。透過彙總多個頭的輸出,模型可以更全面地捕捉輸入序列中的特徵。
下面我們用一個簡單的例子來演示如何使用python實現多頭注意力機制。我們將使用pytorch框架來構建模型。
import torch import torch.nn as nn import torch.nn.functional as F class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiHeadAttention, self).__init__() self.num_heads = num_heads self.d_model = d_model self.query_linear = nn.Linear(d_model, d_model) self.key_linear = nn.Linear(d_model, d_model) self.value_linear = nn.Linear(d_model, d_model) self.output_linear = nn.Linear(d_model, d_model) def forward(self, query, key, value): batch_size = query.size(0) query = self.query_linear(query) key = self.key_linear(key) value = self.value_linear(value) query = query.view(batch_size, -1, self.num_heads, self.d_model// self.num_heads).transpose(1,2) key = key.view(batch_size, -1, self.num_heads, self.d_model // self.num_heads).transpose(1,2) value = value.view(batch_size, -1, self.num_heads, self.d_model // self.num_heads).transpose(1,2) scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d_model // self.num_heads) ** 0.5 attention_weights = F.softmax(scores, dim = -1) output = torch.matmul(attention_weights, value) output = output.transpose(1,2).contiguous().view(batch_size, -1, self.d_model) return self.output_linear(output) if __name__ == "__main__": query = torch.randn(5,10,20) key = torch.randn(5,10,20) value = torch.randn(5,10,20) multi_head_attention = MultiHeadAttention(d_model = 20, num_heads = 4) output = multi_head_attention(query, key, value) print("output.shape: ", output.shape)
執行上面的程式碼,我們可以看到模型輸出的形狀為(5,10,20),說明多頭注意力機制成功執行並得到了輸出。