Attention機制全流程詳解與細節學習筆記
回顧
上一篇筆記裡,我們的初學筆記裡,我們已經對Attention機制的主要內容作了全面的介紹,這篇筆記,主要是補充我在第二次學習Attention機制時對一些細節的理解與記錄。下面我會把整個Attention機制的主要流程通過另外一種方式再過一遍,對其中的一些關鍵引數的維度,這裡也會有所標註。溫故而知新,挺好。
Attention的全流程
我們用上面的圖來說明Attention機制的整個工作流程。
首先,是Decoder部分:
原始輸入是語料分詞後的token_id,分批次傳入Embedding層得到詞向量,再將詞向量傳入Encoder中的特徵提取器進行特徵提取,這裡使用的是RNN系列的模型(RNN、LSTM、GRU),用RNNs代稱,為了更好的捕捉一個句子前後的語義特徵,我們這裡使用雙向的RNNs。兩個方向的RNNs所產生的兩部分隱藏層狀態拼接成一個狀態hs進行輸出。這是後面Attention所需要用到的重要狀態值,它包含了各個輸入詞的語義,在普通的Seq2Seq模型中,它就是生成的語義編碼c。
再看Decoder部分:
解碼部分與編碼部分類似,t-1時刻RNNs輸出一個隱藏狀態h_(t-1),該隱藏狀態在圖中要傳下去,與編碼產生的h_s進行計算,得到一個Score:
注意,當前時刻用到的Decoder的隱藏態h_t-1代表是上一時刻Decoder單元輸出的,因為在RNN中,當前時刻的輸入來自於上一時刻的輸出。
這個Score怎麼計算的呢,一般我們常用的兩個方法如下:
方法一:
兩個向量h_t與h_s直接進行點乘得到Score。中間也可以加一個神經網路引數去學習。點乘其實是求了兩個向量的相似度,這看起來也很合理,越相似的權重越大。
方法二:
感知機的方式。可以寫作:v^T tanh(W_h h_i + W_s s_t + b_attn)。
下一步,將Score進行softmax歸一化,得到權重Attention Weight。
再將Attention Weight與h_s作reduce_sum,其實也就是加權求和得到context vector。形如圖中的0.1h_s1+0.2h_s2+0.5h_s3+0.2h_s4。如下圖:
接下來,再將context vector與Decoder中上一時刻的輸出進行concat,拼接成新的向量作為這一時刻Decoder的輸入:
至此Attention的主要流程就描述完了,Decoder的輸出就可以去接一個全連線層再過一個softmax就可以拿到詞表大小的概率分佈了,這也就得到了我們的預測值,再將預測值跟標籤值算損失函式作反向梯度更新,再更新網路引數,如此往復,也就是神經網路的常規操作了。
程式碼實現中的各個關鍵tensor的維度
這裡記錄一下一般在程式碼實現整個Seq2Seq Attention時,各個環節各個關鍵tensor的維度,先對各個變數做一下定義說明:
batch_size:資料分批時一批資料的大小;
embedding_dim:輸入對應的embedding的維度,預設Encoder與Decoder一致;
enc_max_len:encoder時輸入的最大長度,一般是資料預處理是,階段或padding之後規範化的長度;
dec_max_len:decoder時輸入的最大長度;
enc_hidden_size:encoder中隱藏層的維度;
dec_hidden_size:decoder中隱藏層的維度;
vocab_size:詞表大小。
各個關鍵tensor維度:
Encoder部分:
Encoder的輸入:enc_input.shape=(batch_size,enc_max_len)
輸入因為是token_id,所以得過一個Embedding層,過完Embedding的輸入:enc_input_embedding.shape=(batch_size,enc_max_len,embedding_dim)【每一個詞都變成了對應的詞向量】
過完Encoder的輸出:enc_output.shape=(batch_size,enc_max_len,enc_hidden_size);對應隱藏層的維度:enc_hidden.shape=(batch_size,enc_hidden_size)
Decoder部分:
原始Decoder的輸入:dec_input.shape=(batch_size,dec_max_len)
第一個Encoder的輸入過Embedding層:dec_input_embedding.shape=(batch_size,1,embedding_dim)
隱藏層:dec_hidden.shape=(batch_size,dec_hidden_size)
Attention中間過程及輸出:
score.shape=(batch_size,enc_max_len,1) = Attention_Weight.shape
Attention_Weight與enc_output作reduce_sum,其實是在中間維度上進行的加權求和,得到的context_vextor維度:context_vextor.shape=(batch_size,enc_hidden_size)
將context_vextor升維後與dec_input_embedding進行concat,得到最終每個RNNs的輸入:RNNs_input.shape=(batch_size,1,embedding_dim+enc_hidden_size)
輸出:dec_output.shape=(batch_size,dec_hidden_size)
如果輸出再過一個FC層,則輸出維度(batch_size,vocab_size)
小結
沒什麼特殊小結,紙上得來終覺淺,絕知此事要躬行。程式碼擼一遍就知道了。改天放上程式碼連結。晚安嘞您。
參考文章:
沒有參考文章,感謝HCT張楠老師的講解。
相關文章
- 大模型學習筆記:attention 機制大模型筆記
- transformer中的attention機制詳解ORM
- attention注意力機制學習
- 一千行 MySQL 詳細學習筆記(值得學習與收藏)MySql筆記
- MySql學習筆記--詳細整理--下MySql筆記
- 學習筆記(2)IPC機制筆記
- redis學習筆記(詳細)——高階篇Redis筆記
- MySQL 學習筆記(二)MVCC 機制MySql筆記MVC
- 深度學習中的序列模型演變及學習筆記(含RNN/LSTM/GRU/Seq2Seq/Attention機制)深度學習模型筆記RNN
- 【機器學習科學庫】全md文件筆記:Matplotlib詳細使用方法(已分享,附程式碼)機器學習筆記
- Go學習筆記-GMP詳解Go筆記
- 史上最全、最詳細的 kafka 學習筆記!Kafka筆記
- 【吳恩達深度學習筆記】5.3序列模型和注意力機制Sequence models&Attention mechanism吳恩達深度學習筆記模型
- 深度學習中的注意力機制(Attention Model)深度學習
- 【機器學習科學庫】全md檔案筆記:Matplotlib詳細使用方法(已分享,附程式碼)機器學習筆記
- JVM學習筆記——類載入機制JVM筆記
- LevelDB學習筆記 (2): 整體概覽與讀寫實現細節筆記
- Oracle SCN機制詳細解讀Oracle
- SpringBoot + Spring Security 學習筆記(二)安全認證流程原始碼詳解Spring Boot筆記原始碼
- JVM學習筆記(3)---OutOfMemory詳解JVM筆記
- 英語學習詳細筆記(十一)動名詞筆記
- springmvc學習筆記(全)SpringMVC筆記
- Java學習筆記(一)上傳圖片到七牛雲的詳細實現流程Java筆記
- 超詳細!Postman 安裝與漢化全流程教程Postman
- mysql學習筆記-底層原理詳解MySql筆記
- MIT 6.824 學習筆記(一)--- RPC 詳解MIT筆記RPC
- Nginx變數詳解(學習筆記十九)Nginx變數筆記
- JVM記憶體分配機制與回收策略選擇-JVM學習筆記(2)JVM記憶體筆記
- 快速入門NativeScript,超詳細的NativeScript學習筆記筆記
- Kafka超詳細學習筆記【概念理解,安裝配置】Kafka筆記
- 期望 與 機率論 學習筆記筆記
- 程式設計師筆記| 詳解Eureka 快取機制程式設計師筆記快取
- Java註解與反射學習筆記Java反射筆記
- Jenkins學習筆記第八篇pipeline機制Jenkins筆記
- 【機器學習】李宏毅——自注意力機制(Self-attention)機器學習
- 【恩墨學院】深入剖析 - Oracle SCN機制詳細解讀Oracle
- Nginx 快取機制詳解!非常詳細實用Nginx快取
- 學習記錄Spring Boot 記錄配置細節Spring Boot