transformer中的attention機制詳解

老张哈哈哈發表於2024-07-02

transformer中用到的注意力機制包括self-attention(intra-attention)和傳統的attention(cross-attention),本篇文章將在第一節簡述這兩者的差別,第二節詳述self-attention機制,第三節介紹其實現

  • self-attention和attention的區別
  • self-attention機制
  • self-attention實現

self-attention和attention的區別

傳統attention機制

發生在decoder和encoder之間,decoder可以更多的參考encoder中相關的資訊,以便指導其輸出。attention機制可以分為以下三步

  • 計算algnment score

其中hj時encoder輸出的隱狀態, si是decoder 輸出的隱狀態, eij描述的時輸入位置j和輸出位置i的匹配度

  • 匹配度歸一化,這裡使用softmax進行計算

  • 計算context vector

從表示式可看出,decoder 計算所需的context vector 實際上就是輸入隱狀態的加權和。

具體到RNN中,每個時間步應用attention機制的計算步驟如下

  • decoder RNN 接收 token 的嵌入和初始解碼器隱藏狀態。
  • RNN 處理其輸入,產生輸出和新的隱藏狀態向量 si。輸出被丟棄。
  • attention計算:我們使用encoder輸出的所有的隱藏狀態和 decoder 輸出的si 向量來計算此時間步驟的context vector ci。
  • 我們將 si 和 ci 連線成一個向量。
  • 我們將此向量傳遞給前饋神經網路(與模型聯合訓練)。
  • 前饋神經網路的輸出表示此時間步驟的輸出詞。
  • 對下一個時間步驟重複此操作

self-attention

發生在decoder或者encoder內部,將輸出或者輸入序列內部不同位置關聯起來,以計算序列表徵

self-attention機制

self-attention的實現步驟和attention類似,在attention中計算align score時用到了輸入和輸出的hidden state,但是對於self-attention只需要用到一種,即在encoder中的self-attention只用到encoder層輸出的hidden state, decoder中的self-attention只用到decoder層的hidden state

我們將self-attention拆解為兩部分,1. self-attention計算 2. multi-head attention

self attention計算:scaled dot-product attention

  • 獲取encoder輸入的embeding,並計算每個embedding 的query,key,value,下文簡寫為q,k,v。

其中WQ, WK, WV為去要學習的權重矩陣

  • 接下來我們要計算不同位置之間的關聯度。例如我們要計算位置0處的embedding和其他位置embedding的關聯度,參考傳統attention機制align score的計算方法,我們要用位置0處計算得到的hidden states即query 和其他位置處計算的key進行計算。在self-attention中計算過程如下

相較於傳統attention計算align score,self-attention中多了一步scale,即用key維度的開方對qk結果進行縮放。論文中提出這樣做的理由是

We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by 1/sqrt(dk)

即縮放的目的是為了保證softmax有更加穩定的梯度

  • 得到其他位置對於位置0的關聯度/權重之後,我們就可以計算位置0處包含有上下文資訊的context vector:

  • 利用矩陣運算可以同步求出其他位置的context vector
    獲取輸入向量的Q,K,V矩陣

    運用矩陣計算得到每個位置的結果

multi-head attention

相較於單頭注意力,使用多頭注意力的目的在於

  1. 從不同表徵空間挖掘不同位置之間的關聯。

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

  1. 單頭注意力在計算不同位置間關聯時用到了加權平均,這在一定程度上影響了特徵計算的準確性,因此要用多頭注意力來抵消這種影響

In these models, the number of operations required to relate signals from two arbitrary input or output positions grows in the distance between positions, linearly for ConvS2S and logarithmically for ByteNet. This makes it more difficult to learn dependencies between distant positions [12]. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section 3.2.

單頭注意力包含WQ,WK,WV, 產生一個output,多頭注意力則包含n個WQ,WK,WV,這些引數的權重不共享,產生n個output

這n個output被拼接到一起,並對拼接結果再次進行projection得到最終結果

self-attention實現

  1. 首先是single attention
def attention(query, key, value, mask=None, dropout=None):
    # 獲取維度, query, key, value 的size 均為(batch_size,  n_head, seq_length, hidden_state_length)
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(d_k)
    if mask:
        scores = score.masked_fill(mask==0, -1e9)
    p_atten = scores.softmax(dim = -1)
    if dropout:
        p_atten = dropout(p_atten)
    return torch.matmul(p_atten, vakue), p_atten

在transformer decoder中會在self-attention中使用mask,在encoder中不會用到。因為本篇文章主要講解self-attention因此沒有講解mask的使用,下一篇講解transformer的文章中會具體分析self-attention在decoder和encoder中的區別。

  1. multi head attention實現
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model / h
        self.h = h 
        self.wq =  nn.Linear(d_model, d_model)
        self.wk =  nn.Linear(d_model, d_model)
        self.wv =  nn.Linear(d_model, d_model)
        self.wo =  nn.Linear(d_model, d_model)
        self.atten = None 
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, query, key, value, mask=None):
        if mask:
            # same mask applied to all heads
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        query = self.wq(query).view(nbatches, -1, self.h, self.dk).transpose(1, 2)
        key = self.wk(key).view(nbatches, -1, self.h, self.dk).transpose(1, 2)
        value = self.wv(value).view(nbatches, -1, self.h, self.dk).transpose(1, 2)
        
        x, self.atten = attention(query, key, value, mask, self.dropout)
        # concat n heads outputs 
        x = (
              x.transpose(1,2)
              .contiguous()
              .view(nbatches, -1, self.h*self.d_k))
        del query
        del key
        del value
        return self.wo(x)

從multihead attention的實現中看出,實際上是將維度為(nbatches, seq_length, d_model)的矩陣,利用矩陣變換,得到了一個 (nbatches,h, seq_length,dk)的矩陣,且h*dk = d_model。

在attention中計算的時候所有head平行計算,得到一個(nbatches,h, seq_length,dk)的輸出,對這個輸出結果在進行矩陣變換得到 (nbatches, seq_length, d_model)的矩陣。完成了所謂的‘矩陣拼接’

拼接後的矩陣經過wo計算得到最終結果

ref:
Attention is all you need
The Illustrated Transformer
Visualizing A Neural Machine Translation Model (Mechanics of Seq2seq Models With Attention)
The Annotated Transformer

相關文章