Attention機制全流程詳解與細節學習筆記

知了愛啃程式碼發表於2020-12-02

回顧

上一篇筆記裡,我們的初學筆記裡,我們已經對Attention機制的主要內容作了全面的介紹,這篇筆記,主要是補充我在第二次學習Attention機制時對一些細節的理解與記錄。下面我會把整個Attention機制的主要流程通過另外一種方式再過一遍,對其中的一些關鍵引數的維度,這裡也會有所標註。溫故而知新,挺好。

Attention的全流程

image-20201202015719320

我們用上面的圖來說明Attention機制的整個工作流程。

首先,是Decoder部分:

原始輸入是語料分詞後的token_id,分批次傳入Embedding層得到詞向量,再將詞向量傳入Encoder中的特徵提取器進行特徵提取,這裡使用的是RNN系列的模型(RNN、LSTM、GRU),用RNNs代稱,為了更好的捕捉一個句子前後的語義特徵,我們這裡使用雙向的RNNs。兩個方向的RNNs所產生的兩部分隱藏層狀態拼接成一個狀態hs進行輸出。這是後面Attention所需要用到的重要狀態值,它包含了各個輸入詞的語義,在普通的Seq2Seq模型中,它就是生成的語義編碼c。

再看Decoder部分:

image-20201202015849920

解碼部分與編碼部分類似,t-1時刻RNNs輸出一個隱藏狀態h_(t-1),該隱藏狀態在圖中要傳下去,與編碼產生的h_s進行計算,得到一個Score:

image-20201202015804994

注意,當前時刻用到的Decoder的隱藏態h_t-1代表是上一時刻Decoder單元輸出的,因為在RNN中,當前時刻的輸入來自於上一時刻的輸出。

這個Score怎麼計算的呢,一般我們常用的兩個方法如下:

方法一:

兩個向量h_t與h_s直接進行點乘得到Score。中間也可以加一個神經網路引數去學習。點乘其實是求了兩個向量的相似度,這看起來也很合理,越相似的權重越大。

方法二:

感知機的方式。可以寫作:v^T tanh(W_h h_i + W_s s_t + b_attn)

下一步,將Score進行softmax歸一化,得到權重Attention Weight。

再將Attention Weight與h_s作reduce_sum,其實也就是加權求和得到context vector。形如圖中的0.1h_s1+0.2h_s2+0.5h_s3+0.2h_s4。如下圖:

image-20201202014847762

接下來,再將context vector與Decoder中上一時刻的輸出進行concat,拼接成新的向量作為這一時刻Decoder的輸入:

image-20201202015307812

至此Attention的主要流程就描述完了,Decoder的輸出就可以去接一個全連線層再過一個softmax就可以拿到詞表大小的概率分佈了,這也就得到了我們的預測值,再將預測值跟標籤值算損失函式作反向梯度更新,再更新網路引數,如此往復,也就是神經網路的常規操作了。

程式碼實現中的各個關鍵tensor的維度

這裡記錄一下一般在程式碼實現整個Seq2Seq Attention時,各個環節各個關鍵tensor的維度,先對各個變數做一下定義說明:

batch_size:資料分批時一批資料的大小;

embedding_dim:輸入對應的embedding的維度,預設Encoder與Decoder一致;

enc_max_len:encoder時輸入的最大長度,一般是資料預處理是,階段或padding之後規範化的長度;

dec_max_len:decoder時輸入的最大長度;

enc_hidden_size:encoder中隱藏層的維度;

dec_hidden_size:decoder中隱藏層的維度;

vocab_size:詞表大小。

各個關鍵tensor維度:

Encoder部分:

Encoder的輸入:enc_input.shape=(batch_size,enc_max_len)

輸入因為是token_id,所以得過一個Embedding層,過完Embedding的輸入:enc_input_embedding.shape=(batch_size,enc_max_len,embedding_dim)【每一個詞都變成了對應的詞向量】

過完Encoder的輸出:enc_output.shape=(batch_size,enc_max_len,enc_hidden_size);對應隱藏層的維度:enc_hidden.shape=(batch_size,enc_hidden_size)

Decoder部分:

原始Decoder的輸入:dec_input.shape=(batch_size,dec_max_len)

第一個Encoder的輸入過Embedding層:dec_input_embedding.shape=(batch_size,1,embedding_dim)

隱藏層:dec_hidden.shape=(batch_size,dec_hidden_size)

Attention中間過程及輸出:

score.shape=(batch_size,enc_max_len,1) = Attention_Weight.shape

Attention_Weight與enc_output作reduce_sum,其實是在中間維度上進行的加權求和,得到的context_vextor維度:context_vextor.shape=(batch_size,enc_hidden_size)

將context_vextor升維後與dec_input_embedding進行concat,得到最終每個RNNs的輸入:RNNs_input.shape=(batch_size,1,embedding_dim+enc_hidden_size)

輸出:dec_output.shape=(batch_size,dec_hidden_size)

如果輸出再過一個FC層,則輸出維度(batch_size,vocab_size)

小結

沒什麼特殊小結,紙上得來終覺淺,絕知此事要躬行。程式碼擼一遍就知道了。改天放上程式碼連結。晚安嘞您。

參考文章:
沒有參考文章,感謝HCT張楠老師的講解。

相關文章