先問了下chatgpt:
我正在學習torch.nn.multiheadattention,請告訴我att_mask和key_padding_mask這兩個引數有什麼不同,分別用於處理什麼問題,以及輸出有什麼不同,並給出程式碼示例
chatgpt的回答:
torch.nn.MultiheadAttention
中的 attn_mask
和 key_padding_mask
是兩個非常重要的引數,用於處理不同的注意力機制中的問題。
attn_mask
(Attention Mask)
attn_mask
用於在計算注意力權重時遮蔽某些位置。這在例如自迴歸生成模型中非常有用,可以防止模型在生成當前詞時看到未來的詞。它可以是一個 2D 的布林型矩陣,用於遮蔽特定位置,或是一個上三角矩陣,用於遮蔽未來的位置。
key_padding_mask
key_padding_mask
用於在計算注意力權重時忽略某些填充位置(padding)。這是在處理不同長度的序列時常用的技巧,確保模型不會將填充值(通常為0)作為有效的資訊來處理。
不同之處
attn_mask
是用來遮蔽未來的位置或指定位置,以防止資訊洩露或特定位置的資訊影響計算。key_padding_mask
是用來忽略填充值的位置,確保這些填充值不會影響模型的注意力計算。
程式碼示例
import torch
import torch.nn as nn
# 建立一些示例資料
batch_size = 4
seq_length = 6
embed_dim = 8
num_heads = 2
# 模擬輸入 (batch_size, seq_length, embed_dim)
x = torch.rand(batch_size, seq_length, embed_dim)
# 模擬填充後的序列
padded_x = nn.utils.rnn.pad_sequence(
[torch.rand(3, embed_dim), torch.rand(4, embed_dim), torch.rand(5, embed_dim), torch.rand(6, embed_dim)],
batch_first=True
)
# 建立 SelfAttention 例項
self_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
# Attention mask: 遮蔽未來的位置(上三角矩陣)
attn_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()
# Key padding mask: 遮蔽填充值(padding)
key_padding_mask = torch.tensor([
[False, False, False, True, True, True], # 只有前三個位置有效
[False, False, False, False, True, True], # 只有前四個位置有效
[False, False, False, False, False, True],# 只有前五個位置有效
[False, False, False, False, False, False]# 所有位置都有效
])
# 計算注意力
attn_output, attn_weights = self_attention(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
print("Attention Output:", attn_output)
print("Attention Weights:", attn_weights)
解釋程式碼
- 模擬輸入資料:建立一個隨機的輸入張量
x
,形狀為(batch_size, seq_length, embed_dim)
。 - 模擬填充後的序列:使用
nn.utils.rnn.pad_sequence
建立一個填充後的序列padded_x
。 - 建立 SelfAttention 例項:例項化
nn.MultiheadAttention
,設定embed_dim
和num_heads
。 - 建立 Attention Mask:
attn_mask
是一個上三角矩陣,用於遮蔽未來的位置。 - 建立 Key Padding Mask:
key_padding_mask
是一個布林型張量,用於遮蔽填充值的位置。 - 計算注意力:呼叫
self_attention
,傳入輸入張量、attn_mask
和key_padding_mask
。
輸出
- Attention Output:經過注意力機制後的輸出張量。
- Attention Weights:注意力權重,展示了每個位置的權重值。
這樣,你可以透過 attn_mask
和 key_padding_mask
靈活地控制注意力機制的行為,以適應不同的任務需求。