A PyTorch Implementation of the Transformer (1)

Posted by iSherryZhang on 2024-10-12

This post walks through building a Transformer network from scratch in PyTorch and using it for a small translation task.

First, let's break the Transformer down. A Transformer consists of an encoder and a decoder (Encoder-Decoder). The encoder is a stack of blocks, each made of Multi-Head Attention + Feed-Forward Network; the decoder is a stack of blocks, each made of Multi-Head Attention + Multi-Head Attention + Feed-Forward Network.
[Figure: the Transformer encoder-decoder architecture]

import numpy as np
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, corpus) -> None:
        super().__init__()
        self.src_emb = nn.Embedding(len(corpus.src_vocab), d_embedding) # word embedding
        self.pos_emb = nn.Embedding.from_pretrained(get_sin_enc_table(corpus.src_len + 1, d_embedding), freeze=True) # position embedding
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(encoder_n_layers)])
    
    def forward(self, enc_inputs):
        pos_indices = torch.arange(1, enc_inputs.size(1)+1).unsqueeze(0).to(enc_inputs)
        enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(pos_indices)
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
        enc_self_attn_weights = []
        for layer in self.layers:
            enc_outputs, enc_self_attn_weight = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attn_weights.append(enc_self_attn_weight)
        return enc_outputs, enc_self_attn_weights

class Decoder(nn.Module):
    def __init__(self, corpus) -> None:
        super().__init__()
        self.tgt_emb = nn.Embedding(len(corpus.tgt_vocab), d_embedding) # word embedding
        self.pos_emb = nn.Embedding.from_pretrained(get_sin_enc_table(corpus.tgt_len + 1, d_embedding), freeze=True) # position embedding
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(decoder_n_layers)])
    
    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        pos_indices = torch.arange(1, dec_inputs.size(1)+1).unsqueeze(0).to(dec_inputs)
        dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(pos_indices)
        # padding mask for decoder self-attention
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)
        # subsequent (look-ahead) mask that hides future positions
        dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
        # combine the two masks
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # the encoder-decoder (cross) attention only needs a padding mask, built from the decoder and encoder inputs

        dec_self_attn_weights = []
        dec_enc_attn_weights = []
        for layer in self.layers:
            dec_outputs, dec_self_attn_weight, dec_enc_attn_weight = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attn_weights.append(dec_self_attn_weight)
            dec_enc_attn_weights.append(dec_enc_attn_weight)
        return dec_outputs, dec_self_attn_weights, dec_enc_attn_weights

class Transformer(nn.Module):
    def __init__(self, corpus) -> None:
        super().__init__()
        self.encoder = Encoder(corpus)
        self.decoder = Decoder(corpus)
        self.projection = nn.Linear(d_embedding, len(corpus.tgt_vocab), bias=False)
    
    def forward(self, enc_inputs, dec_inputs):
        enc_outputs, enc_self_attn_weights = self.encoder(enc_inputs)
        dec_outputs, dec_self_attn_weights, dec_enc_attn_weights = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)
        return dec_logits, enc_self_attn_weights, dec_self_attn_weights, dec_enc_attn_weights
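
The Encoder and Decoder above depend on several helpers that are defined elsewhere: get_sin_enc_table, get_attn_pad_mask and get_attn_subsequent_mask, plus the EncoderLayer/DecoderLayer classes that are built later from Multi-Head Attention and the Feed-Forward Network. For completeness, here is a minimal sketch of what the three mask/position helpers might look like, assuming the padding token id is 0:

def get_sin_enc_table(n_position, d_embedding):
    # sinusoidal position-encoding table, shape [n_position, d_embedding]
    sinusoid_table = np.zeros((n_position, d_embedding))
    for pos in range(n_position):
        for i in range(d_embedding):
            sinusoid_table[pos, i] = pos / np.power(10000, 2 * (i // 2) / d_embedding)
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # even dimensions: sin
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # odd dimensions: cos
    return torch.FloatTensor(sinusoid_table)

def get_attn_pad_mask(seq_q, seq_k):
    # True where the key position is padding (token id 0 is assumed to be <pad>)
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_mask = seq_k.eq(0).unsqueeze(1)               # [batch, 1, len_k]
    return pad_mask.expand(batch_size, len_q, len_k)  # [batch, len_q, len_k]

def get_attn_subsequent_mask(seq):
    # upper-triangular mask that hides future positions for the decoder
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequent_mask = np.triu(np.ones(attn_shape), k=1)
    return torch.from_numpy(subsequent_mask).byte().to(seq.device)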

From this it is clear that implementing the Transformer boils down to implementing two basic building blocks: Multi-Head Attention and the Feed-Forward Network.

Multi-Head Attention

To implement multi-head attention, we first need to implement (single-head) attention.

How attention is computed:

  1. Apply linear transformations to the input to obtain the Q, K and V matrices
  2. Dot product of Q and K, scaling, then softmax
  3. Use the result as weights for a weighted sum over V

Multi-Head Attention simply runs several such attention heads in parallel:

  1. Concatenate the outputs of all heads
  2. Pass the concatenation through a fully-connected layer, so that the Multi-Head Attention output has the same dimension as its input

[Figure: Multi-Head Attention]

Let's walk through the Multi-Head Attention computation step by step:


Assume the input sequence has length n and each token is encoded with dimension d, so the input has shape $ (n, d) $.

Weight matrices: $ W_Q: (d, d_q), W_K: (d, d_q), W_V: (d, d_v) $

  1. The resulting Q, K, V are: $ Q: (n, d_q), K: (n, d_q), V: (n, d_v) $
  2. Multiply Q by the transpose of K: $ Q \cdot K^T : (n, d_q) \cdot (d_q, n) = (n, n) $; the value at position (i, j) measures the similarity between the i-th and the j-th token
  3. Scaling: does not change the shape of the matrix, only its values
  4. softmax: normalizes the values of the matrix row by row
  5. Weighted sum over V: $ softmax(\frac {Q \cdot K^T} {\sqrt{d_k}})\cdot V = (n, n)\cdot(n, d_v) = (n, d_v) $
  6. For an input of shape $ (n, d) $, a single head therefore outputs $ (n, d_v) $; concatenating the heads gives $ (n_{heads}, n, d_v) $
  7. Transpose and apply the fully-connected layer: $ (n_{heads}, n, d_v) -> (n, n_{heads}*d_v) -> (n, d) $ (see the single-head sketch right after this list)
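
To make the shape bookkeeping concrete, here is a minimal single-head sketch of steps 1-5 above (no masking; the weight matrices are random tensors purely for illustration):

def single_head_attention(x, W_Q, W_K, W_V):
    # x: [n, d]; W_Q, W_K: [d, d_q]; W_V: [d, d_v]
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V                    # [n, d_q], [n, d_q], [n, d_v]
    scores = Q @ K.transpose(-1, -2) / Q.size(-1) ** 0.5   # [n, n], scaled dot products
    weights = torch.softmax(scores, dim=-1)                # [n, n], each row sums to 1
    return weights @ V                                     # [n, d_v]

# quick shape check: n = 5 tokens, d = 16, d_q = d_v = 8
x = torch.randn(5, 16)
out = single_head_attention(x, torch.randn(16, 8), torch.randn(16, 8), torch.randn(16, 8))
print(out.shape)  # torch.Size([5, 8])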

The full Multi-Head Attention implementation is as follows:

class MultiHeadAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.W_Q = nn.Linear(d_embedding, d_k * n_heads)
        self.W_K = nn.Linear(d_embedding, d_k * n_heads)
        self.W_V = nn.Linear(d_embedding, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_embedding)
        self.layer_norm = nn.LayerNorm(d_embedding)
    
    def forward(self, Q, K, V, attn_mask):
        '''
            Q: [batch, len_q, d_embedding]
            K: [batch, len_k, d_embedding]
            V: [batch, len_v, d_embedding]
            attn_mask: [batch, len_q, len_k]
        '''
        residual, batch_size = Q, Q.size(0)
        # step1: linear projections of the input, then reshape into heads
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2) # [batch, n_heads, len_q, d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2) # [batch, n_heads, len_k, d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2) # [batch, n_heads, len_v, d_v]
        # step2: attention scores = dot product + scaling
        scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d_k) # [batch_size, n_heads, len_q, len_k]
        # step3: apply the attention mask, filling positions where the mask is True with a very large negative value
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # [batch_size, n_heads, len_q, len_k]
        scores.masked_fill_(attn_mask, -1e9)
        # step4: normalize the attention scores with softmax
        weights = nn.Softmax(dim=-1)(scores)
        # step5: context vectors = weighted sum over V
        context = torch.matmul(weights, v_s) # [batch_size, n_heads, len_q, d_v]
        # step6: concatenate the heads and apply the fully-connected layer
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # [batch_size, len_q, n_heads * d_v]
        output = self.linear(context) # [batch_size, len_q, d_embedding]
        # step7: residual connection + layer normalization
        output = self.layer_norm(output + residual)
        return output, weights
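
A quick sanity check of the class above with dummy tensors. The hyperparameter values below are assumptions for illustration (the classes read them as globals); any consistent values work:

# hypothetical hyperparameter values, assumed to be defined globally
d_embedding, n_heads, d_k, d_v = 512, 8, 64, 64

batch_size, seq_len = 2, 6
x = torch.randn(batch_size, seq_len, d_embedding)             # dummy input
attn_mask = torch.zeros(batch_size, seq_len, seq_len).bool()  # nothing masked

mha = MultiHeadAttention()
output, weights = mha(x, x, x, attn_mask)  # self-attention: Q = K = V = x
print(output.shape)   # torch.Size([2, 6, 512]) -- same shape as the input
print(weights.shape)  # torch.Size([2, 8, 6, 6]) -- one attention map per head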

Feed-Forward Network

In both the Encoder and the Decoder, every attention layer is followed by a Position-Wise Feed-Forward Network, which further transforms the features. This computation is applied to each position of the input sequence independently: positions are not shuffled, mixed together, or processed recurrently, hence the name Position-Wise Feed-Forward.

The formula is:

$ F(x) = \max(0, xW_1 + b_1)W_2 + b_2 $

The computation is shown in the figure: a conv1d/fc layer first maps the input sequence to a higher dimension (d_ff is a tunable hyperparameter, typically 4 times d), and a second layer then maps it back to the original dimension.

[Figure: the Position-Wise Feed-Forward Network computation]

Implementation with conv1d:

nn.Conv1d(in_channels, out_channels, kernel_size, ...)

$ (batch, n, d)-> (batch, d, n) -> (batch, d_ff, n) -> (batch, d, n) -> (batch, n, d) $

The first conv1d's parameters are:

nn.Conv1d(d, d_ff, 1, ...)

The second conv1d's parameters are:

nn.Conv1d(d_ff, d, 1, ...)

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_ff=2048) -> None:
        super().__init__()
        # 1-D convolution that maps the input up to the higher dimension d_ff
        self.conv1 = nn.Conv1d(in_channels=d_embedding, out_channels=d_ff, kernel_size=1)
        # 1-D convolution that maps it back to the original dimension
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_embedding, kernel_size=1)
        self.layer_norm = nn.LayerNorm(d_embedding)

    def forward(self, inputs):
        '''
            inputs: [batch_size, len_q, embedding_dim]
            output: [batch_size, len_q, embedding_dim]
        '''
        residual = inputs
        output = self.conv1(inputs.transpose(1, 2))
        output = nn.ReLU()(output)
        output = self.conv2(output)
        output = self.layer_norm(output.transpose(1, 2) + residual)
        return output
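
A quick check that this block preserves the input shape (d_embedding is again assumed to be defined globally, e.g. 512):

ffn = PoswiseFeedForwardNet()
x = torch.randn(2, 6, d_embedding)  # [batch_size, len_q, d_embedding]
print(ffn(x).shape)                 # torch.Size([2, 6, 512]) -- shape unchanged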

Implementation with fully-connected (fc) layers:

nn.Linear(in_features, out_features, bias=True)

$ (batch, n, d)-> (batch, n, d_ff) -> (batch, n, d) $

The first fc layer's parameters are:

nn.Linear(d, d_ff, bias=True)

The second fc layer's parameters are:

nn.Linear(d_ff, d, bias=True)

class PoswiseFeedForwardNet_fc(nn.Module):
    def __init__(self, d_ff=2048) -> None:
        super().__init__()
        # fully-connected layer that maps the input up to the higher dimension d_ff
        self.fc1 = nn.Linear(d_embedding, d_ff, bias=True)
        # fully-connected layer that maps it back to the original dimension
        self.fc2 = nn.Linear(d_ff, d_embedding, bias=True)
        self.layer_norm = nn.LayerNorm(d_embedding)

    def forward(self, inputs):
        '''
            inputs: [batch_size, len_q, embedding_dim]
            output: [batch_size, len_q, embedding_dim]
        '''
        residual = inputs
        output = self.fc1(inputs)
        output = nn.ReLU()(output)
        output = self.fc2(output)
        output = self.layer_norm(output + residual)
        return output
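
With MultiHeadAttention and the feed-forward network in place, the EncoderLayer and DecoderLayer used by the Encoder and Decoder at the top of this post can be assembled. Their exact definitions are not shown in this part; the following is a sketch of what they presumably look like, matching the call signatures used above:

class EncoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        # self-attention: Q = K = V = encoder input
        enc_outputs, attn_weights = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn_weights

class DecoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        # masked self-attention over the decoder input
        dec_outputs, self_attn_weights = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # encoder-decoder (cross) attention: Q from the decoder, K and V from the encoder output
        dec_outputs, enc_attn_weights = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, self_attn_weights, enc_attn_weights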

Reference:

GPT圖解
