Implementing Multi-Head Attention in Python

Posted by 小丑_jk on 2024-07-09

Multi-head attention is a neural network component for processing sequence data that is widely used in natural language processing. It helps a model understand and learn the information in an input sequence more effectively, improving performance across a variety of tasks.

Multi-head attention is a refinement of the basic attention mechanism: it introduces several attention heads, each of which can attend to information at different positions in the input sequence. By combining the outputs of all the heads, the model captures the features of the input sequence more comprehensively.
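Concretely, each head applies scaled dot-product attention to its own learned linear projections of the queries Q, keys K, and values V, and the head outputs are concatenated and passed through a final linear layer. In the notation of the original Transformer paper:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V), \qquad \mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O$$

where d_k is the per-head dimension (d_model / num_heads in the code below) and the W matrices are the learned projections.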

Below, we walk through a simple example of how to implement multi-head attention in Python, building the model with the PyTorch framework.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        # Linear projections for queries, keys, values, and the final output
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, query, key, value):
        batch_size = query.size(0)
        # Project the inputs
        query = self.query_linear(query)
        key = self.key_linear(key)
        value = self.value_linear(value)
        # Split into heads: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        d_k = self.d_model // self.num_heads
        query = query.view(batch_size, -1, self.num_heads, d_k).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, d_k).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, d_k).transpose(1, 2)
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        scores = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)
        # Merge heads back: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.output_linear(output)

if __name__ == "__main__":
    # Toy inputs: batch of 5 sequences, length 10, embedding dimension 20
    query = torch.randn(5, 10, 20)
    key = torch.randn(5, 10, 20)
    value = torch.randn(5, 10, 20)
    multi_head_attention = MultiHeadAttention(d_model=20, num_heads=4)
    output = multi_head_attention(query, key, value)
    print("output.shape: ", output.shape)

Running the code above prints an output of shape (5, 10, 20), the same as the input (batch_size, seq_len, d_model): the multi-head attention module runs correctly and preserves the model dimension, as attention layers should.
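As a sanity check, PyTorch also ships a built-in nn.MultiheadAttention module that implements the same computation. A minimal sketch, assuming PyTorch ≥ 1.9 (which introduced the batch_first argument):

import torch
import torch.nn as nn

# Built-in equivalent; batch_first=True matches our (batch, seq_len, d_model) layout
mha = nn.MultiheadAttention(embed_dim=20, num_heads=4, batch_first=True)
query = torch.randn(5, 10, 20)
key = torch.randn(5, 10, 20)
value = torch.randn(5, 10, 20)
attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # torch.Size([5, 10, 20])
print(attn_weights.shape)  # torch.Size([5, 10, 10]); weights are averaged over heads by default

Note that the built-in module has its own weight initialization and parameter packing, so its numerical outputs will not match our hand-rolled class exactly, but the shapes and the underlying computation agree.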
