淺析注意力(Attention)機制

KevinScott0582發表於2024-11-17

Attention顧名思義,說明這項機制是模仿人腦的注意力機制建立的,我們不妨從這個角度展開理解

2.1 人腦的注意力機制

人腦的注意力機制,就是將有限的注意力資源分配到當前關注的任務,或關注的目標之上,暫時忽略其他不重要的因素,這是人類利用有限的注意力資源從大量資訊中快速篩選出高價值資訊的手段,是人類在長期進化中形成的一種生存機制,極大地提高了資訊處理的效率與準確性。

舉個例子,就以上班為例,今天本該又是摸魚的一天,但你的“恩人”突然交給你一項任務——查詢關於“注意力機制”的資料並總結,並於下班之前向她彙報。於是你不得不放下手上的娛樂節目,轉而應付恩人派下的工作。你選定了“注意力機制”作為關鍵詞開始搜尋,在搜尋引擎的推送下陰差陽錯的看到了這篇博文(這是不可能的),又因為這篇博文關鍵資訊太少而選擇忽略了它,努力一番後又查到了一些資料,彙總的大量初步結果並提交恩人,按時下班,happy ending!

死嘍

上面的例子中其實出現了多次“識別關鍵要素”或“篩選重要資訊”的動作,這便是注意力機制的體現。而深度學習中的注意力機制從本質上講和人類的選擇性注意力機制類似,核心目標也是從眾多資訊中選擇出對當前任務目標更關鍵的資訊。

2.2 為什麼需要Attention

在之前的博文《理解LSTM》中提到過,LSTM透過引入邏輯閘,從結構層面上有效解決了序列長距離依賴問題(梯度消失)。然而,面對超長序列時(例如一段500多詞的文字),LSTM也可能失效。而 Attention 機制可以更好地解決序列長距離依賴問題,並且具有平行計算能力

我們還是以文字問題舉例, 看一看RNN或LSTM處理超長文字序列時會發生什麼?

死嘍

可以看到, 為了理解當前文字,我們有時需要獲得很久之前的歷史狀態下的某些資訊。而RNNs從結構層面上無形中新增了一種假設,那就是當前的文字只和臨近區域的文字具有較強的關聯性,而和距離較遠的上下文關聯不大或沒有關聯。很明顯,這樣的假設是不恰當的,這就限制了RNNs處理文字的長度和理解文字的精度,而Attention的出現則幾乎打破了模型對於文字長度的限制。

採用RNN架構的網路均具有這種侷限, 包括LSTM, GRU等等

為了進一步理解,讓我們從迴圈神經網路的老大難問題——機器翻譯問題入手。
在翻譯任務中,源語言和目標語言的單詞數和語序往往不是一一對應的,這種輸入和輸出都是不定長序列的任務,稱為 Seq2Seq,以英語和德語為例,如下圖所示。

翻譯

為了解決這個問題,我們創造了Encoder-Decoder結構的迴圈神經網路。

  • 它先透過一個Encoder迴圈神經網路讀入所有的待翻譯句子中的單詞,得到一個包含原文所有資訊的中間隱藏層,接著把中間隱藏層狀態輸入Decoder網路,一個詞一個詞的輸出翻譯句子。
  • 這樣子,無論輸入中的關鍵詞語有著怎樣的先後次序,由於都被打包到中間層一起輸入後方網路,我們的Encoder-Decoder網路都可以很好地處理這些詞的輸出位置和形式了。

問題在於,由於中間狀態\(C\)來自輸入網路最後的隱藏層,一般來說它是一個大小固定的向量。既然是大小固定的向量,那麼它能儲存的資訊就是有限的,當句子長度不斷變長,由於後方的decoder網路的所有資訊都來自中間狀態,中間狀態需要表達的資訊就越來越多。在語句資訊量過大時,中間狀態就作為一個資訊的瓶頸阻礙翻譯了。這時我們很容易聯想到,如果網路能夠在處理長文字時懂得篩選關鍵資訊, 而不是將全部文字都作為都作為中間狀態儲存,是不是就可以突破文字長度的限制了?這便是注意力機制的由來。

Encoder-Decoder(編碼-解碼)是深度學習中非常常見的一個模型框架,比如無監督演算法的auto-encoding就是用編碼-解碼的結構設計並訓練的;比如這兩年比較熱的image caption的應用,就是CNN-RNN的編碼-解碼框架;再比如神經網路機器翻譯NMT模型,往往就是LSTM-LSTM的編碼-解碼框架。因此,準確的說,Encoder-Decoder並不是一個具體的模型,而是一類框架。Encoder和Decoder部分可以是任意的文字,語音,影像,影片資料,模型可以採用CNN,RNN,BiRNN、LSTM、GRU等等。所以基於Encoder-Decoder架構,我們可以設計出各種各樣的應用演算法。

2.3 Attention的核心思想

在正式介紹注意力機制之前,我們先要明確以下幾個概念:

  • 查詢(Query):用於記錄模型當前關注的任務資訊,向量形式
  • 鍵(Key):用於記錄輸入序列中每個資訊單元的識別符號或標籤, 用於與Query進行比較,以決定哪些資訊是相關的, 在機器翻譯任務中,Key可能是源語言的每個單詞或短語的特徵向量
  • 值(Value):Value通常包含輸入序列的實際資訊,當Query和Key匹配時,相應的Value值被用於計算輸出
  • 分數(Score): Score又稱為注意力分數,用於表示Query和Key的匹配程度,Score越高,模型對當前資訊單元的關注度越高

我們仍以機器翻譯為例,透過引入注意力機制,讓生成詞不是隻能關注全域性的語義編碼向量c,而是增加了一個“注意力範圍”,表示接下來輸出詞時候要重點關注輸入序列中的哪些部分,然後根據關注的區域來產生下一個輸出,如下圖所示。

翻譯
此時生成目標句子單詞的過程就成了下面的形式: $$ \begin{aligned}&\mathbf{y}_{1}=\mathbf{f}\mathbf{1}(\mathbf{C}_{1})\\&\mathbf{y}_{2}=\mathbf{f}\mathbf{1}(\mathbf{C}_{2},\mathbf{y}_{1})\\&\mathbf{y}_{3}=\mathbf{f}\mathbf{1}(\mathbf{C}_{3},\mathbf{y}_{1},\mathbf{y}_{2})\end{aligned} $$ 這樣一來,由於每個生成詞關注的語義編碼向量都各不相同,且資訊容量都被限定在了一個範圍內,無需一次性關注區域性特徵,也就解決了序列長度過長帶來的問題。

在理解了注意力機制的作用之後,我們就可以對其具體步驟加以描述了(正片開始)。

翻譯

如上圖所示,Attention 通常可以進行如下描述,表示為將 Query(Q) 和 key-value pairs(把 Values 拆分成了鍵值對的形式) 對映到輸出上,其中 query、每個 key、每個 value 都是向量,輸出是 \(V\) 中所有 values 的加權,其中權重是由 Query 和每個 key 計算出來的,計算方法分為三步:

  1. 第一步:計算並比較 Q 和 K 的相似度,用 f 來表示:\(f(Q,K_i)\quad i=1,2,\cdots,m\), 一般第一步計算方法包括四種
  • 點乘(transformer使用):\(f(Q,K_i)=Q^TK_i\)
  • 加權:\(f(Q,K_i)=Q^TWK_i\)
  • 拼接權重:\(f(Q,K_i)=W[Q^T;K_i]\)
  • 感知器:\(f(Q,K_i)=V^T\tanh(WQ+UK_i)\)
  1. 將得到的相似度進行 softmax 操作,進行歸一化,得到注意力分數:\(\alpha_i=softmax(\frac{f(Q,K_i)}{\sqrt{d}_k})\)
  2. 針對計算出來的權重 \(\alpha_{i}\),對 \(V\) 中的所有 values 進行加權求和計算,得到 Attention 向量:\(Attention=\sum_{i=1}^m\alpha_iV_i\)

2.4 Attention程式碼實現

最後附一個Attention機制的程式碼示例:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAttention(nn.Module):
    def __init__(self, input_dim):
        super(SimpleAttention, self).__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Compute attention scores (dot product of queries and keys)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.input_dim ** 0.5

        # Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

# Example usage
input_dim = 64
seq_length = 10
batch_size = 5

# Dummy input tensor (batch_size, seq_length, input_dim)
x = torch.rand(batch_size, seq_length, input_dim)

# Initialize the attention module
attention = SimpleAttention(input_dim)

# Forward pass
output, attention_weights = attention(x)

print("Output shape:", output.shape)  # Expected: (batch_size, seq_length, input_dim)
print("Attention weights shape:", attention_weights.shape)  # Expected: (batch_size, seq_length, seq_length)

相關文章