關於attention中對padding的處理:mask

waterrr。發表於2024-05-21

先問了下chatgpt:
我正在學習torch.nn.multiheadattention,請告訴我att_mask和key_padding_mask這兩個引數有什麼不同,分別用於處理什麼問題,以及輸出有什麼不同,並給出程式碼示例

chatgpt的回答:
torch.nn.MultiheadAttention 中的 attn_maskkey_padding_mask 是兩個非常重要的引數,用於處理不同的注意力機制中的問題。

attn_mask (Attention Mask)

attn_mask 用於在計算注意力權重時遮蔽某些位置。這在例如自迴歸生成模型中非常有用,可以防止模型在生成當前詞時看到未來的詞。它可以是一個 2D 的布林型矩陣,用於遮蔽特定位置,或是一個上三角矩陣,用於遮蔽未來的位置。

key_padding_mask

key_padding_mask 用於在計算注意力權重時忽略某些填充位置(padding)。這是在處理不同長度的序列時常用的技巧,確保模型不會將填充值(通常為0)作為有效的資訊來處理。

不同之處

  • attn_mask 是用來遮蔽未來的位置或指定位置,以防止資訊洩露或特定位置的資訊影響計算。
  • key_padding_mask 是用來忽略填充值的位置,確保這些填充值不會影響模型的注意力計算。

程式碼示例

import torch
import torch.nn as nn

# 建立一些示例資料
batch_size = 4
seq_length = 6
embed_dim = 8
num_heads = 2

# 模擬輸入 (batch_size, seq_length, embed_dim)
x = torch.rand(batch_size, seq_length, embed_dim)

# 模擬填充後的序列
padded_x = nn.utils.rnn.pad_sequence(
    [torch.rand(3, embed_dim), torch.rand(4, embed_dim), torch.rand(5, embed_dim), torch.rand(6, embed_dim)], 
    batch_first=True
)

# 建立 SelfAttention 例項
self_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Attention mask: 遮蔽未來的位置(上三角矩陣)
attn_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()

# Key padding mask: 遮蔽填充值(padding)
key_padding_mask = torch.tensor([
    [False, False, False, True, True, True],  # 只有前三個位置有效
    [False, False, False, False, True, True], # 只有前四個位置有效
    [False, False, False, False, False, True],# 只有前五個位置有效
    [False, False, False, False, False, False]# 所有位置都有效
])

# 計算注意力
attn_output, attn_weights = self_attention(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)

print("Attention Output:", attn_output)
print("Attention Weights:", attn_weights)

解釋程式碼

  1. 模擬輸入資料:建立一個隨機的輸入張量 x,形狀為 (batch_size, seq_length, embed_dim)
  2. 模擬填充後的序列:使用 nn.utils.rnn.pad_sequence 建立一個填充後的序列 padded_x
  3. 建立 SelfAttention 例項:例項化 nn.MultiheadAttention,設定 embed_dimnum_heads
  4. 建立 Attention Maskattn_mask 是一個上三角矩陣,用於遮蔽未來的位置。
  5. 建立 Key Padding Maskkey_padding_mask 是一個布林型張量,用於遮蔽填充值的位置。
  6. 計算注意力:呼叫 self_attention,傳入輸入張量、attn_maskkey_padding_mask

輸出

  • Attention Output:經過注意力機制後的輸出張量。
  • Attention Weights:注意力權重,展示了每個位置的權重值。

這樣,你可以透過 attn_maskkey_padding_mask 靈活地控制注意力機制的行為,以適應不同的任務需求。

相關文章