Writing a GPT by Hand, Step by Step

Posted by zrq96 on 2024-08-02

This post documents the process of hand-writing a GPT step by step in a top-down fashion, following the nanoGPT project. Reading it assumes basic familiarity with the Transformer, GPT, and PyTorch.

Below are all the Python libraries we will use:

import math  # will use math.sqrt
from dataclasses import dataclass # for configuration

import torch
import torch.nn as nn
import torch.nn.functional as F  # F.softmax
from torch import FloatTensor, LongTensor  # for type annotation

First, let's define the overall GPT skeleton, containing only the most essential methods: __init__, forward, and generate.

  • We include generate from the very beginning because, after all, a GPT is a language model and should natively support text-to-text generation. Note that forward and generate have different return types: forward returns the logits of the next token, whereas generate produces a whole sequence of token ids in one go. Other methods, such as from_pretrained, can be left for later.
  • In the __init__ method, we define the main components of a GPT, namely
    • Token Embedding;
    • Position Embedding;
    • Transformer Blocks; and
    • Language Modelling Head
class MyGPT(nn.Module):
    def __init__(self, conf) -> None:
        super().__init__()
        self.conf = conf

        self.tok_embd = ...
        self.pos_embd = ...
        self.tfm_blocks = ...
        self.lm_head = ...

    def forward(self, x_id: LongTensor) -> FloatTensor:
        '''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
        pass

    def generate(self, x_id: LongTensor, max_new_tokens: int) -> LongTensor:
        '''(padded) sequence of token ids -> sequence of token ids'''
        pass

With the skeleton in place, we can start filling in the details. But before implementing these three methods, let's first think about which hyper-parameters are needed to define a GPT model. Here we only consider the most important ones, namely:

  • the vocabulary size
  • the dimension of the token embeddings
  • the number of transformer layers (blocks)
  • the context length
  • and, if we use multi-head attention, the number of attention heads
@dataclass
class MyGPTConfig:
    '''minimum hyper-parameters'''
    vocab_size: int = 50257
    ctx_win: int = 1024
    dim_embd: int = 768
    n_attn_head: int = 12
    n_tfm_block: int = 12
    # Dropout probability 
    dropout: float = 0.0
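
These defaults roughly match the GPT-2 "small" configuration (a 50257-token vocabulary, a 1024-token context window, 768-dimensional embeddings, 12 heads and 12 blocks). For quick local experiments one might instead instantiate a much smaller configuration; the values below are purely an illustrative assumption, not anything prescribed by nanoGPT:

# a hypothetical toy configuration for quick experiments
toy_conf = MyGPTConfig(
    vocab_size=1000,   # e.g. a small character-level vocabulary
    ctx_win=64,        # short context window
    dim_embd=128,      # small embedding dimension
    n_attn_head=4,     # dim_embd must be divisible by the number of heads
    n_tfm_block=2,     # only two transformer blocks
    dropout=0.1,
)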

Of all the GPT components, the most important and most complex is the stack of transformer blocks, so we will leave that for last and first take care of the simpler parts and methods. The token embedding and position embedding are both ordinary nn.Embedding layers, and the language modelling head is just an ordinary linear layer that maps the contextual embedding of the preceding tokens to the logits (unnormalized scores over the vocabulary) of the next token.

class MyGPT(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf

        self.tok_embd = nn.Embedding(conf.vocab_size, conf.dim_embd)
        self.pos_embd = nn.Embedding(conf.ctx_win, conf.dim_embd)
        self.tfm_blocks = nn.ModuleList([MyTransformerBlock(conf) for _ in range(conf.n_tfm_block)])
        self.lm_head = nn.Linear(conf.dim_embd, conf.vocab_size)

    def forward(self, x_id: LongTensor) -> FloatTensor:
        '''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
        pass

    def generate(self, x_id: LongTensor, max_new_tokens: int) -> LongTensor:
        '''(padded) sequence of token ids -> sequence of token ids'''
        pass

The forward method embeds the input token ids into word vectors, generates a batch of position-encoding vectors, feeds the sum into the Transformer blocks to obtain contextual embeddings of the input, and finally predicts the logits of the next token. Experience shows that Dropout and LayerNorm broadly improve the performance of neural networks, so we add these two components in __init__ and make use of them in forward as well.

class MyGPT(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf

        self.tok_embd = nn.Embedding(conf.vocab_size, conf.dim_embd)
        self.pos_embd = nn.Embedding(conf.ctx_win, conf.dim_embd)
        self.tfm_blocks = nn.ModuleList([MyTransformerBlock(conf) for _ in range(conf.n_tfm_block)])
        self.lm_head = nn.Linear(conf.dim_embd, conf.vocab_size)

        self.dropout = nn.Dropout(conf.dropout)
        self.layer_norm = nn.LayerNorm(conf.dim_embd)

    def forward(self, x_id: LongTensor) -> FloatTensor:
        '''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
        pos = torch.arange(x_id.shape[1], device=x_id.device)
        tok_embd = self.tok_embd(x_id)
        pos_embd = self.pos_embd(pos)
        x_embd = tok_embd + pos_embd
        x_embd = self.dropout(x_embd)
        for tfm_block in self.tfm_blocks:
            x_embd = tfm_block(x_embd)
        x_embd = self.layer_norm(x_embd)
        # note: using list [-1] to preserve the time dimension
        logits = self.lm_head(x_embd[:, [-1], :])

        return logits

    def generate(self, x_id: LongTensor, max_new_tokens: int) -> LongTensor:
        '''(padded) sequence of token ids -> sequence of token ids'''
        pass

The generate method uses forward to extend the input sequence step by step: after forward predicts the logits of the next token, we pick the most likely token, append it to the sequence, and feed the new sequence back in to generate the next token, until the maximum number of new tokens is reached. One thing to watch out for: if the generated sequence grows longer than the GPT's context window, we must crop off the oldest (leftmost) part and keep only the most recent (rightmost) tokens.

class MyGPT(nn.Module):
    ...
    def generate(self, x_id: LongTensor, max_new_tokens: int) -> LongTensor:
        '''(padded) sequence of token ids -> sequence of token ids'''
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at ctx_win
            x_id_cond = x_id if x_id.size(1) <= self.conf.ctx_win else x_id[:, -self.conf.ctx_win:]
            logits = self.forward(x_id_cond)
            new_tok_id = logits.argmax(dim=-1)
            x_id = torch.cat([x_id, new_tok_id], dim=1)
        return x_id

However, this greedy way of generating often does not work well in practice; we actually want to add a bit of randomness:

  1. After obtaining the logits, instead of always picking the most likely token, we randomly sample a token according to the logits and append it to the sequence;
  2. Before sampling, we can rescale the logits to control just how random this "randomness" is. In the implementation we use a temperature parameter when converting logits to probabilities: probs = softmax(logits/temperature). The larger the temperature, the more uniform the resulting token probabilities, i.e. the more random the sampling (see the short sketch after this list);
  3. Rather than sampling from the whole vocabulary, we can restrict sampling to the k most likely tokens (top-k sampling), so that very unlikely tokens can never be chosen.
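
Here is a minimal, self-contained sketch (separate from the model code, with made-up logits) showing how temperature and top-k filtering reshape the sampling distribution:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # hypothetical logits over a 4-token vocabulary

print(F.softmax(logits / 1.0, dim=-1))  # temperature 1.0: roughly [0.61, 0.22, 0.14, 0.03]
print(F.softmax(logits / 0.5, dim=-1))  # temperature 0.5: sharper, the top token dominates
print(F.softmax(logits / 2.0, dim=-1))  # temperature 2.0: flatter, closer to uniform

# top-k filtering with k=2: logits below the 2nd-largest value are set to -inf,
# so after softmax only the two most likely tokens can ever be sampled
k = 2
v, _ = torch.topk(logits, k)
filtered = logits.masked_fill(logits < v[-1], float('-inf'))
print(F.softmax(filtered, dim=-1))      # roughly [0.73, 0.27, 0.00, 0.00]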

So our final generate method takes a few more arguments and is a bit more complex than the first version:

class MyGPT(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf

        self.tok_embd = nn.Embedding(conf.vocab_size, conf.dim_embd)
        self.pos_embd = nn.Embedding(conf.ctx_win, conf.dim_embd)
        self.tfm_blocks = nn.ModuleList([MyTransformerBlock(conf) for _ in range(conf.n_tfm_block)])
        self.lm_head = nn.Linear(conf.dim_embd, conf.vocab_size)

        self.dropout = nn.Dropout(conf.dropout)
        self.layer_norm = nn.LayerNorm(conf.dim_embd)

    def forward(self, x_id: LongTensor) -> FloatTensor:
        '''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
        pos = torch.arange(x_id.shape[1], device=x_id.device)
        tok_embd = self.tok_embd(x_id)
        pos_embd = self.pos_embd(pos)
        x_embd = tok_embd + pos_embd
        x_embd = self.dropout(x_embd)
        for tfm_block in self.tfm_blocks:
            x_embd = tfm_block(x_embd)
        x_embd = self.layer_norm(x_embd)
        # note: using list [-1] to preserve the time dimension
        logits = self.lm_head(x_embd[:, [-1], :])

        return logits

    def generate(self, x_id: LongTensor, max_new_tokens: int, temperature: float = 1.0, top_k=None) -> LongTensor:
        '''(padded) sequence of token ids -> sequence of token ids'''
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at ctx_win
            x_id_cond = x_id if x_id.size(1) <= self.conf.ctx_win else x_id[:, -self.conf.ctx_win:]
            logits = self.forward(x_id_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.shape[-1])) # top_k cannot exceed vocab_size
                # logits below the k-th largest are set to -Inf, so the probabilities of those tokens become 0
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            new_tok_id = torch.multinomial(probs, num_samples=1)
            x_id = torch.cat([x_id, new_tok_id], dim=1)

        return x_id
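
As a purely illustrative usage sketch (the configuration and token ids below are made up, and MyTransformerBlock is only implemented further down), generation with sampling might look like this:

conf = MyGPTConfig(vocab_size=1000, ctx_win=64, dim_embd=128, n_attn_head=4, n_tfm_block=2)
model = MyGPT(conf)
model.eval()  # disable dropout during generation

prompt = torch.LongTensor([[10, 20, 30]])  # hypothetical prompt token ids
with torch.no_grad():
    out = model.generate(prompt, max_new_tokens=5, temperature=0.8, top_k=50)
print(out.shape)  # torch.Size([1, 8]): the 3 prompt tokens plus 5 newly sampled tokens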

Now comes the main event: implementing the most important component, the transformer block, i.e. MyTransformerBlock in the code. This component can itself be split into two sub-components: the core one is, of course, self-attention, followed by a non-linear transformation, a simple MLP. Data flows through these two sub-components in turn and then on to the next block. Of course, before the data enters each sub-component it can (and should) be layer-normalized first.

class MyTransformerBlock(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()

        self.ln1 = nn.LayerNorm(conf.dim_embd)
        self.attn = MyMultiHeadAttention(conf)
        self.ln2 = nn.LayerNorm(conf.dim_embd)
        self.mlp = MyMLP(conf)

    def forward(self, x: FloatTensor) -> FloatTensor:
        '''[batch_size, seq_len, dim_embd] -> [batch_size, seq_len, dim_embd]'''
        x = x + self.attn(self.ln1(x))  # layer norm + attention + residual
        x = x + self.mlp(self.ln2(x))   # layer norm + MLP + residual
        return x

The latter of the two, the MLP module, can be implemented as a two-layer fully connected feed-forward network:

class MyMLP(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:  
        super().__init__()
        self.fc1 = nn.Linear(conf.dim_embd, conf.dim_embd * 4)
        # the 4x expansion of the hidden dimension follows the GPT-2 convention
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(conf.dim_embd * 4, conf.dim_embd)
        self.dropout = nn.Dropout(conf.dropout)

    def forward(self, x: FloatTensor) -> FloatTensor:
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

As for the core component, the attention mechanism, let's first sketch a rough design:

class MyMultiHeadAttention(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf
    
    def forward(self, x: FloatTensor) -> FloatTensor:
        pass

What does "attention" actually mean here? In everyday usage, paying attention means devoting most of your resources (time, mental effort) to the small fraction of information that is important or relevant, and only a little to the large remainder that is unimportant or irrelevant. Using document retrieval as an analogy: given a query, we find the documents K that are highly relevant to that query, and then concentrate our resources on processing the content V of those documents.

Mathematically, we can formulate the attention mechanism described above as the following function:

\[Attention(Q,K,V) = Softmax(QK^{\top}/\sqrt{d_k})V \]

Here \(Q\) is a batch of query vectors, \(K\) is a batch of document (key) vectors, and \(V\) holds the content of those documents. \(QK^{\top}/\sqrt{d_k}\) can be read as computing the similarity between each document and the given query; \(Softmax\) squashes these similarities into the range 0 to 1, telling us how much resource to allocate to each document; and finally multiplying by the matrix \(V\) processes the document content according to that attention. When a GPT processes a sequence, every token acts simultaneously as a query, a key k, and a value v: when we process a particular token, all tokens up to and including that token are content we need to attend to, which is where the name "attention mechanism" comes from. Translated into code:

class MyMultiHeadAttention(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf

    def attention(self, q, k, v: FloatTensor) -> FloatTensor:
        d_k = k.shape[-1]
        return F.softmax(q @ k.mT / math.sqrt(d_k), dim=-1) @ v
        
    def forward(self, x: FloatTensor) -> FloatTensor:
        q, k, v = self.make_QKV(x)  # make_QKV (projecting x to Q, K, V) is implemented below
        return self.attention(q, k, v)
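
Before adding the causal mask, we can sanity-check the bare formula on a tiny made-up example (3 tokens, 4-dimensional embeddings), where the queries, keys and values all come directly from the input:

import math
import torch
import torch.nn.functional as F

x = torch.randn(3, 4)  # pretend: 3 tokens, each a 4-dimensional embedding
q, k, v = x, x, x      # self-attention: Q, K and V all come from x
d_k = k.shape[-1]

weights = F.softmax(q @ k.mT / math.sqrt(d_k), dim=-1)
print(weights.shape)        # torch.Size([3, 3]): one attention distribution per token
print(weights.sum(dim=-1))  # tensor([1., 1., 1.]): each row is a probability distribution

out = weights @ v
print(out.shape)            # torch.Size([3, 4]): same shape as the input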

When modelling/generating in an auto-regressive fashion, we must ensure that each token can only attend to the tokens before it (to its left) and never to the tokens after it (to its right). In code, we can use a mask like mask=[True,True,...,False,False] to force the attention weights of the not-yet-generated right-hand sub-sequence to 0 (i.e. set the scores to negative infinity before the softmax). In practice we can also apply a bit of dropout. The upgraded attention function then becomes:

class MyMultiHeadAttention(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf
        # causal mask to ensure that attention is only applied to the left in the input sequence
        bias = torch.tril(torch.ones(conf.ctx_win, conf.ctx_win)).view(1, 1, conf.ctx_win, conf.ctx_win)
        self.register_buffer("bias", bias)
        self.attn_dropout = nn.Dropout(conf.dropout)
        self.resid_dropout = nn.Dropout(conf.dropout)

    def attention(self, q, k, v, mask) -> FloatTensor:
        scores = q @ k.mT / math.sqrt(k.shape[-1])
        # ensure that attention is only applied to the left
        scores = scores.masked_fill(mask==False, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        attn = attn @ v
        return attn
        
    def forward(self, x: FloatTensor) -> FloatTensor:
        q, k, v = self.make_QKV(x)
        seq_len = x.shape[1]
        mask = self.bias[:, :, :seq_len, :seq_len]
        y = self.attention(q, k, v, mask)
        y = self.resid_dropout(y)
        return y
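
To make the masking concrete, here is a small standalone sketch (with a made-up sequence length of 4) showing what the causal mask looks like and how masked_fill turns future positions into zero attention weights:

import torch

seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))  # lower-triangular causal mask
print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])

scores = torch.randn(seq_len, seq_len)                 # pretend these are q @ k.mT / sqrt(d_k)
scores = scores.masked_fill(mask == 0, float('-inf'))  # future positions get -inf
attn = torch.softmax(scores, dim=-1)                   # -inf becomes 0 after softmax
print(attn)  # row i has non-zero weights only on positions 0..i, and every row sums to 1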

So where do the matrices \(Q,K,V\) in self-attention come from? In GPT they are all obtained by linearly projecting the input vectors \(x\) into three different spaces: we learn three matrices \(W_Q,W_K,W_V\) and multiply the input \(x\) by each of them. For coding convenience (and following common practice), we set the output dimension of \(W_Q,W_K,W_V\) to dim_embd as well.

class MyMultiHeadAttention(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf
        self.W_q = nn.Linear(conf.dim_embd, conf.dim_embd)
        self.W_k = nn.Linear(conf.dim_embd, conf.dim_embd)
        self.W_v = nn.Linear(conf.dim_embd, conf.dim_embd)

        # causal mask to ensure that attention is only applied to the left in the input sequence
        bias = torch.tril(torch.ones(conf.ctx_win, conf.ctx_win)).view(1, 1, conf.ctx_win, conf.ctx_win)
        self.register_buffer("bias", bias)

        self.attn_dropout = nn.Dropout(conf.dropout)
        self.resid_dropout = nn.Dropout(conf.dropout)

    def attention(self, q, k, v, mask) -> FloatTensor:
        scores = q @ k.mT / math.sqrt(k.shape[-1])
        # ensure that attention is only applied to the left
        scores = scores.masked_fill(mask==0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        attn = attn @ v
        return attn
    
    def make_QKV(self, x):  
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)
        return q, k, v
        
    def forward(self, x: FloatTensor) -> FloatTensor:
        q, k, v = self.make_QKV(x)
        seq_len = x.shape[1]
        mask = self.bias[:, :, :seq_len, :seq_len]
        y = self.attention(q, k, v, mask)
        y = self.resid_dropout(y)
        return y

In a Transformer we use not just one attention head but several; intuitively this corresponds to attending from multiple perspectives, or to different aspects of the sequence. The outputs of the individual heads are concatenated and passed through one more linear transformation (multiplication by a matrix \(W^O\)):

\[MultiHead(Q, K, V ) = Concat(head_1, ..., head_h)W^O \]

where \(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)\).

Putting these ideas together, the final code for multi-head attention is:

class MyMultiHeadAttention(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf

        self.W_q = nn.Linear(conf.dim_embd, conf.dim_embd)
        self.W_k = nn.Linear(conf.dim_embd, conf.dim_embd)
        self.W_v = nn.Linear(conf.dim_embd, conf.dim_embd)
        self.W_out = nn.Linear(conf.dim_embd, conf.dim_embd)

        # causal mask to ensure that attention is only applied to the left in the input sequence
        bias = torch.tril(torch.ones(conf.ctx_win, conf.ctx_win)).view(1, 1, conf.ctx_win, conf.ctx_win)
        self.register_buffer("bias", bias)

        self.attn_dropout = nn.Dropout(conf.dropout)
        self.resid_dropout = nn.Dropout(conf.dropout)

    def attention(self, q, k, v, mask) -> FloatTensor:
        scores = q @ k.mT / math.sqrt(k.shape[-1])
        # ensure that attention is only applied to the left
        scores = scores.masked_fill(mask==0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        attn = attn @ v
        return attn
    
    def make_QKV(self, x, batch_size, seq_len, n_attn_head, dim_embd):
        q = self.W_q(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
        k = self.W_k(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
        v = self.W_v(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
        return q, k, v
    
    def forward(self, x: FloatTensor) -> FloatTensor:
        batch_size, seq_len, dim_embd = x.shape
        n_attn_head, dim_embd = self.conf.n_attn_head, self.conf.dim_embd
        q, k, v = self.make_QKV(x, batch_size, seq_len, n_attn_head, dim_embd)
        mask = self.bias[:, :, :seq_len, :seq_len]
        y = self.attention(q, k, v, mask)
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, dim_embd) # re-assemble all head outputs side by side
        y = self.W_out(y) # output projection
        y = self.resid_dropout(y)
        return y
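
The view/transpose dance in make_QKV is easiest to follow by tracking shapes. A minimal sketch, assuming the default configuration (dim_embd=768, n_attn_head=12) and a made-up batch of 2 sequences of length 5:

conf = MyGPTConfig()
mha = MyMultiHeadAttention(conf)

x = torch.randn(2, 5, conf.dim_embd)  # [batch_size=2, seq_len=5, dim_embd=768]
q, k, v = mha.make_QKV(x, 2, 5, conf.n_attn_head, conf.dim_embd)
print(q.shape)  # torch.Size([2, 12, 5, 64]): [batch, n_head, seq_len, head_dim = 768 // 12]

y = mha(x)
print(y.shape)  # torch.Size([2, 5, 768]): the block preserves [batch, seq_len, dim_embd]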

At this point, we have written a GPT by hand, step by step. The complete code is as follows:

'''
Hand-made GPT, adapted from [nanoGPT](https://github.com/karpathy/nanoGPT/)
'''
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, FloatTensor


@dataclass
class MyGPTConfig:
    vocab_size: int = 50257
    ctx_win: int = 1024
    dim_embd: int = 768
    n_attn_head: int = 12
    n_tfm_block: int = 12
    dropout: float = 0.0
    use_bias: bool = True


class MyMLP(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:  
        super().__init__()
        self.fc1 = nn.Linear(conf.dim_embd, conf.dim_embd * 4)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(conf.dim_embd * 4, conf.dim_embd)
        self.dropout = nn.Dropout(conf.dropout)

    def forward(self, x: FloatTensor) -> FloatTensor:
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class MyMultiHeadAttention(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf

        self.W_q = nn.Linear(conf.dim_embd, conf.dim_embd)
        self.W_k = nn.Linear(conf.dim_embd, conf.dim_embd)
        self.W_v = nn.Linear(conf.dim_embd, conf.dim_embd)
        self.W_out = nn.Linear(conf.dim_embd, conf.dim_embd)

        # causal mask to ensure that attention is only applied to the left in the input sequence
        bias = torch.tril(torch.ones(conf.ctx_win, conf.ctx_win)).view(1, 1, conf.ctx_win, conf.ctx_win)
        self.register_buffer("bias", bias)

        self.attn_dropout = nn.Dropout(conf.dropout)
        self.resid_dropout = nn.Dropout(conf.dropout)

    def attention(self, q, k, v, mask) -> FloatTensor:
        scores = q @ k.mT / math.sqrt(k.shape[-1])
        # ensure that attention is only applied to the left
        scores = scores.masked_fill(mask==0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        attn = attn @ v
        return attn
    
    def make_QKV(self, x, batch_size, seq_len, n_attn_head, dim_embd):
        q = self.W_q(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
        k = self.W_k(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
        v = self.W_v(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
        return q, k, v
    
    def forward(self, x: FloatTensor) -> FloatTensor:
        batch_size, seq_len, dim_embd = x.shape
        n_attn_head, dim_embd = self.conf.n_attn_head, self.conf.dim_embd
        q, k, v = self.make_QKV(x, batch_size, seq_len, n_attn_head, dim_embd)
        mask = self.bias[:, :, :seq_len, :seq_len]
        y = self.attention(q, k, v, mask)
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, dim_embd) # re-assemble all head outputs side by side
        y = self.W_out(y) # output projection
        y = self.resid_dropout(y)
        return y


class MyTransformerBlock(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()

        self.ln1 = nn.LayerNorm(conf.dim_embd)
        self.attn = MyMultiHeadAttention(conf)
        self.ln2 = nn.LayerNorm(conf.dim_embd)
        self.mlp = MyMLP(conf)

    def forward(self, x: FloatTensor) -> FloatTensor:
        '''[batch_size, seq_len, dim_embd] -> [batch_size, seq_len, dim_embd]'''
        x = x + self.attn(self.ln1(x))  # layer norm + attention + residual
        x = x + self.mlp(self.ln2(x))   # layer norm + MLP + residual
        return x


class MyGPT(nn.Module):
    def __init__(self, conf: MyGPTConfig) -> None:
        super().__init__()
        self.conf = conf
        self.tok_embd = nn.Embedding(conf.vocab_size, conf.dim_embd)
        self.pos_embd = nn.Embedding(conf.ctx_win, conf.dim_embd)
        self.tfm_blocks = nn.ModuleList([MyTransformerBlock(conf) for _ in range(conf.n_tfm_block)])
        self.lm_head = nn.Linear(conf.dim_embd, conf.vocab_size)

        self.dropout = nn.Dropout(conf.dropout)
        self.layer_norm = nn.LayerNorm(conf.dim_embd)

    def forward(self, x_id: LongTensor) -> FloatTensor:
        '''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
        pos = torch.arange(x_id.shape[1], device=x_id.device)
        tok_embd = self.tok_embd(x_id)
        pos_embd = self.pos_embd(pos)
        x_embd = tok_embd + pos_embd
        x_embd = self.dropout(x_embd)
        for tfm_block in self.tfm_blocks:
            x_embd = tfm_block(x_embd)
        x_embd = self.layer_norm(x_embd)
        # note: using list [-1] to preserve the time dimension
        logits = self.lm_head(x_embd[:, [-1], :])

        return logits

    def generate(self, x_id: LongTensor, max_new_tokens: int, temperature: float = 1.0, top_k=None) -> LongTensor:
        '''(padded) sequence of token ids -> sequence of token ids'''
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at ctx_win
            x_id_cond = x_id if x_id.size(1) <= self.conf.ctx_win else x_id[:, -self.conf.ctx_win:]
            logits = self.forward(x_id_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.shape[-1])) # top_k cannot exceed vocab_size
                logits[logits < v[:, [-1]]] = -float('Inf')             # logits below the k-th largest are set to -Inf, so their probabilities become 0

            probs = F.softmax(logits, dim=-1)
            new_tok_id = torch.multinomial(probs, num_samples=1)
            x_id = torch.cat([x_id, new_tok_id], dim=1)

        return x_id

if __name__ == '__main__':
    conf = MyGPTConfig()
    model = MyGPT(conf)
    print(model)
    inp = torch.LongTensor([
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
        ])
    print(model(inp))
    print(model.generate(inp, max_new_tokens=3))
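    # optional sanity checks (an illustrative addition to the demo above, not in the original listing)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"number of parameters: {n_params/1e6:.1f}M")
    print(model(inp).shape)                                      # torch.Size([2, 1, 50257]): next-token logits per sequence
    print(model.generate(inp, max_new_tokens=3, top_k=5).shape)  # torch.Size([2, 8]): 5 input tokens + 3 generated tokens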
