多頭自注意層
上一篇描述了單頭多注意層,但在實際應用中,通常使用的是多頭自注意層,多頭自注意層是由多個單頭的組合。
1. 數學形式
輸入:\(X=\{x_1, x_2,...,x_m\}\),\(x_i\)是\(d_{in}\times1\)的向量。
引數:每個單頭自注意層都有三個引數矩陣,\(W_q:d_q*d_{in}\); \(W_k:d_q*d_{in}\); \(W_v:d_{out}*d_{in}\),多頭自注意層總共有3\(l\)個引數矩陣,\(l\)表示自注意層的個數。
輸出:每個單頭自注意層的輸出為\(C=\{c_1, c_2,...,c_m\}\),\(c_i\)是\(d_{out}\times1\)的向量,多頭的輸出就是有\(l\)個C矩陣,然後將所有單頭自注意層對應位置的輸出做連線。最終的每個輸出\(c_i=[c_i^1; c_i^2; c_i^3;...;c_i^l]\)。
2.Pytorch程式碼實現(多頭自注意層)
使用一個大矩陣,將所有引數矩陣並行起來計算。計算過程和單層自注意層相同,最後將多頭注意力的輸出連線起來。
import torch
import torch.nn as nn
from math import sqrt
class MultiHeadSelfAttention(nn.Module):
def __init__(self, d_in, d_k, d_out, num_heads=8):
super(MultiHeadSelfAttention, self).__init__()
assert d_k % num_heads == 0 and d_out % num_heads == 0 # dk和dout必須是多頭數量的倍數,因為dk和dout表示所有頭的總引數量
self.din = d_in
self.dq = d_k
self.dout = d_out
self.num_heads = num_heads
self.Wq = nn.Linear(self.din, self.dq, bias=False)
self.Wk = nn.Linear(self.din, self.dq, bias=False)
self.Wv = nn.Linear(self.din, self.dout, bias=False)
self._norm_fact = 1/sqrt(self.dq//num_heads) # "//" 為整除運算子
def forward(self, x):
m, din = x.shape
assert din == self.din
nh = self.num_heads
dk = self.dq // nh # 每一個頭的dq大小
dv = self.dout // nh # 每一個頭的dout大小
# 第一步
Q = self.Wq(x).reshape(m, nh, dk).transpose(0, 1) # nh*m*dk
K = self.Wk(x).reshape(m, nh, dk).transpose(0, 1)
V = self.Wv(x).reshape(m, nh, dv).transpose(0, 1)
# 第二步
A = torch.softmax(torch.matmul(Q, K.transpose(1, 2))*self._norm_fact, dim=-1)
# 第三步
C = torch.matmul(A, V) # nh, m, dv
# 將輸出進行連線
C = C.transpose(0, 1).reshape(m, self.dout)
return C