自注意力機制(2)-多頭自注意層

吃瓜的哲学發表於2024-09-23

多頭自注意層

上一篇描述了單頭多注意層,但在實際應用中,通常使用的是多頭自注意層,多頭自注意層是由多個單頭的組合。

1. 數學形式

輸入:\(X=\{x_1, x_2,...,x_m\}\)\(x_i\)\(d_{in}\times1\)的向量。

引數:每個單頭自注意層都有三個引數矩陣,\(W_q:d_q*d_{in}\); \(W_k:d_q*d_{in}\); \(W_v:d_{out}*d_{in}\),多頭自注意層總共有3\(l\)個引數矩陣,\(l\)表示自注意層的個數。

輸出:每個單頭自注意層的輸出為\(C=\{c_1, c_2,...,c_m\}\)\(c_i\)\(d_{out}\times1\)的向量,多頭的輸出就是有\(l\)個C矩陣,然後將所有單頭自注意層對應位置的輸出做連線。最終的每個輸出\(c_i=[c_i^1; c_i^2; c_i^3;...;c_i^l]\)

2.Pytorch程式碼實現(多頭自注意層)

使用一個大矩陣,將所有引數矩陣並行起來計算。計算過程和單層自注意層相同,最後將多頭注意力的輸出連線起來。

import torch
import torch.nn as nn
from math import sqrt

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_in, d_k, d_out, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        assert d_k % num_heads == 0 and d_out % num_heads == 0   # dk和dout必須是多頭數量的倍數,因為dk和dout表示所有頭的總引數量
        self.din = d_in
        self.dq = d_k
        self.dout = d_out
        self.num_heads = num_heads

        self.Wq = nn.Linear(self.din, self.dq, bias=False)
        self.Wk = nn.Linear(self.din, self.dq, bias=False)
        self.Wv = nn.Linear(self.din, self.dout, bias=False)

        self._norm_fact = 1/sqrt(self.dq//num_heads)  # "//" 為整除運算子


    def forward(self, x):
        m, din = x.shape
        assert din == self.din

        nh = self.num_heads
        dk = self.dq // nh         # 每一個頭的dq大小
        dv = self.dout // nh       # 每一個頭的dout大小

        # 第一步
        Q = self.Wq(x).reshape(m, nh, dk).transpose(0, 1)   # nh*m*dk
        K = self.Wk(x).reshape(m, nh, dk).transpose(0, 1)
        V = self.Wv(x).reshape(m, nh, dv).transpose(0, 1)

        # 第二步
        A = torch.softmax(torch.matmul(Q, K.transpose(1, 2))*self._norm_fact, dim=-1)

        # 第三步
        C = torch.matmul(A, V)   # nh, m, dv

        # 將輸出進行連線
        C = C.transpose(0, 1).reshape(m, self.dout)
        
        return C



相關文章