This post documents the process of hand-writing a GPT step by step in a top-down programming style, following the nanoGPT project. Reading it requires basic familiarity with the Transformer, GPT, and PyTorch.
Below are all the Python libraries we will use:
import math # will use math.sqrt
from dataclasses import dataclass # for configuration
import torch
import torch.nn as nn
import torch.nn.functional as F # F.softmax
from torch import FloatTensor, LongTensor # for type annotation
First, we define the overall GPT skeleton, which for now contains only the most essential methods: `__init__`, `forward`, and `generate`.
- `generate` is included right from the start because GPT is, after all, a language model and should natively support generating text from text. Note that `forward` and `generate` have different return types: `forward` returns the logits of the next token, while `generate` produces a whole sequence of token ids in one go. Other methods, such as `from_pretrained`, can be deferred for now.
- In the `__init__` method we define the main components of a GPT, namely:
  - Token Embedding;
  - Position Embedding;
  - Transformer Blocks; and
  - Language Modelling Head
class MyGPT(nn.Module):
def __init__(self, conf) -> None:
super().__init__()
self.conf = conf
self.tok_embd = ...
self.pos_embd = ...
self.tfm_blocks = ...
self.lm_head = ...
def forward(self, x_id: LongTensor) -> FloatTensor:
'''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
pass
def generate(self, x_id: LongTensor, max_new_tokens: int) -> LongTensor:
'''(padded) sequence of token ids -> sequence of token ids'''
pass
With the overall skeleton in place, we can start filling in the details. But before implementing these three methods, we should first think about which hyper-parameters are needed to define a GPT model. Here we only consider the most essential ones:
- vocabulary size
- embedding dimension
- number of transformer layers (blocks)
- context window length
- and, if we use multi-head attention, the number of heads
@dataclass
class MyGPTConfig:
'''minimum hyper-parameters'''
vocab_size: int = 50257
ctx_win: int = 1024
dim_embd: int = 768
n_attn_head: int = 12
n_tfm_block: int = 12
# Dropout probability
dropout: float = 0.0
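The defaults above follow the GPT-2 (small) configuration. For quick local experiments the dataclass can just as easily be instantiated with much smaller values; the numbers below are arbitrary toy choices, only meant as an illustration:
# a toy configuration for quick experiments; the values are arbitrary
toy_conf = MyGPTConfig(
    vocab_size=1000,
    ctx_win=64,
    dim_embd=128,
    n_attn_head=4,   # dim_embd must be divisible by n_attn_head
    n_tfm_block=2,
    dropout=0.1,
)
print(toy_conf)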
Of all the GPT components, the most important and most complex one is the transformer blocks, so we will implement those last and first take care of the simpler components and methods. The token embedding and position embedding are both plain `nn.Embedding` layers, and the language modelling head is just an ordinary linear layer that maps the contextual embedding of all preceding tokens to the logits (unnormalized log-probabilities) of the next token.
class MyGPT(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
self.tok_embd = nn.Embedding(conf.vocab_size, conf.dim_embd)
self.pos_embd = nn.Embedding(conf.ctx_win, conf.dim_embd)
self.tfm_blocks = nn.ModuleList([MyTransformerBlock(conf) for _ in range(conf.n_tfm_block)])
self.lm_head = nn.Linear(conf.dim_embd, conf.vocab_size)
def forward(self, x_id: LongTensor) -> FloatTensor:
'''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
pass
def generate(self, x_id: LongTensor, max_new_tokens: int) -> LongTensor:
'''(padded) sequence of token ids -> sequence of token ids'''
pass
The `forward` method embeds the input token ids into word vectors, generates the corresponding position-encoding vectors, feeds their sum through the Transformer Blocks to obtain contextual embeddings of the input, and finally predicts the logits of the next token. Experience shows that Dropout and LayerNorm broadly improve the performance of neural networks, so we add these two components in `__init__` and remember to use them in `forward` as well.
class MyGPT(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
self.tok_embd = nn.Embedding(conf.vocab_size, conf.dim_embd)
self.pos_embd = nn.Embedding(conf.ctx_win, conf.dim_embd)
self.tfm_blocks = nn.ModuleList([MyTransformerBlock(conf) for _ in range(conf.n_tfm_block)])
self.lm_head = nn.Linear(conf.dim_embd, conf.vocab_size)
self.dropout = nn.Dropout(conf.dropout)
self.layer_norm = nn.LayerNorm(conf.dim_embd)
def forward(self, x_id: LongTensor) -> FloatTensor:
'''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
pos = torch.arange(x_id.shape[1], device=x_id.device)
tok_embd = self.tok_embd(x_id)
pos_embd = self.pos_embd(pos)
x_embd = tok_embd + pos_embd
x_embd = self.dropout(x_embd)
for tfm_block in self.tfm_blocks:
x_embd = tfm_block(x_embd)
x_embd = self.layer_norm(x_embd)
# note: using list [-1] to preserve the time dimension
logits = self.lm_head(x_embd[:, [-1], :])
return logits
def generate(self, x_id: LongTensor, max_new_tokens: int) -> LongTensor:
'''(padded) sequence of token ids -> sequence of token ids'''
pass
The `generate` method uses `forward` to grow the input sequence step by step: after `forward` predicts the logits of the next token, we pick the most probable token, append it to the sequence, then feed the extended sequence back in to generate the next token, repeating until the maximum number of new tokens is reached. One thing to note: if the sequence grows beyond the GPT's context window, we have to crop away some of the older tokens and keep only the most recent (rightmost) part.
class MyGPT(nn.Module):
...
def generate(self, x_id: LongTensor, max_new_tokens: int) -> LongTensor:
'''(padded) sequence of token ids -> sequence of token ids'''
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at ctx_win
x_id_cond = x_id if x_id.size(1) <= self.conf.ctx_win else x_id[:, -self.conf.ctx_win:]
logits = self.forward(x_id_cond)
new_tok_id = logits.argmax(dim=-1)
x_id = torch.cat([x_id, new_tok_id], dim=1)
return x_id
However, this greedy approach usually does not generate very good text; in practice we want to inject some randomness:
- After obtaining the logits, instead of always picking the most probable token, we sample a token according to the logits and append it to the sequence;
- Before sampling, we can rescale the logits to control just how random this sampling is. In the implementation we use a temperature parameter when converting logits into probabilities: `probs = softmax(logits / temperature)`. The larger the temperature, the more uniform each token's probability becomes, i.e. the more random the sampling (a small sketch illustrating this follows the list below);
- Instead of considering the whole vocabulary, we can restrict sampling to the k most probable tokens (top-k sampling), so that all other tokens get probability zero.
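To make the effect of temperature concrete, here is a minimal standalone sketch (the logits values are made up purely for illustration) showing how dividing by a larger temperature flattens the softmax distribution:
import torch
import torch.nn.functional as F

fake_logits = torch.tensor([2.0, 1.0, 0.1])  # made-up logits for a 3-token vocabulary
for temperature in (0.5, 1.0, 2.0):
    probs = F.softmax(fake_logits / temperature, dim=-1)
    print(temperature, probs)
# low temperature sharpens the distribution towards the largest logit,
# high temperature pushes it towards uniform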
So, in the end, our `generate` method takes a few extra arguments and is a bit more involved than the original version:
class MyGPT(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
self.tok_embd = nn.Embedding(conf.vocab_size, conf.dim_embd)
self.pos_embd = nn.Embedding(conf.ctx_win, conf.dim_embd)
self.tfm_blocks = nn.ModuleList([MyTransformerBlock(conf) for _ in range(conf.n_tfm_block)])
self.lm_head = nn.Linear(conf.dim_embd, conf.vocab_size)
self.dropout = nn.Dropout(conf.dropout)
self.layer_norm = nn.LayerNorm(conf.dim_embd)
def forward(self, x_id: LongTensor) -> FloatTensor:
'''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
pos = torch.arange(x_id.shape[1], device=x_id.device)
tok_embd = self.tok_embd(x_id)
pos_embd = self.pos_embd(pos)
x_embd = tok_embd + pos_embd
x_embd = self.dropout(x_embd)
for tfm_block in self.tfm_blocks:
x_embd = tfm_block(x_embd)
x_embd = self.layer_norm(x_embd)
# note: using list [-1] to preserve the time dimension
logits = self.lm_head(x_embd[:, [-1], :])
return logits
def generate(self, x_id: LongTensor, max_new_tokens: int, temperature=1.0, top_k:int=1) -> LongTensor:
'''(padded) sequence of token ids -> sequence of token ids'''
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at ctx_win
x_id_cond = x_id if x_id.size(1) <= self.conf.ctx_win else x_id[:, -self.conf.ctx_win:]
logits = self.forward(x_id_cond)
logits = logits[:, -1, :] / temperature
if top_k > 1:
v, _ = torch.topk(logits, min(top_k, logits.shape[-1])) # top_k cannot exceed vocab_size
# logits smaller than the k-th largest logit are set to -inf, so the probabilities of those tokens become 0
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
new_tok_id = torch.multinomial(probs, num_samples=1)
x_id = torch.cat([x_id, new_tok_id], dim=1)
return x_id
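As a usage sketch, sampling-based generation might be called as below; this assumes all the components implemented later in this post (in particular `MyTransformerBlock`) are already in place, and the prompt is just random token ids since we do not deal with a tokenizer here:
conf = MyGPTConfig()
model = MyGPT(conf)
model.eval()  # switch off dropout for generation
prompt = torch.randint(0, conf.vocab_size, (1, 8))  # a made-up prompt of 8 token ids
with torch.no_grad():
    out = model.generate(prompt, max_new_tokens=5, temperature=0.8, top_k=50)
print(out.shape)  # torch.Size([1, 13]): the 8 prompt ids plus 5 sampled ones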
Now for the main event: implementing the most important component, the transformer block, i.e. `MyTransformerBlock` in the code. This component can itself be split into two sub-components: the core of it is naturally the self-attention mechanism, followed by a non-linear transformation, a simple MLP. Data flows through these two sub-components in turn and then on to the next block. Of course, the data can (and should) be layer-normalized before entering each sub-component.
class MyTransformerBlock(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.ln1 = nn.LayerNorm(conf.dim_embd)
self.attn = MyMultiHeadAttention(conf)
self.ln2 = nn.LayerNorm(conf.dim_embd)
self.mlp = MyMLP(conf)
def forward(self, x: FloatTensor) -> FloatTensor:
'''[batch_size, seq_len, dim_embd] -> [batch_size, seq_len, dim_embd]'''
x = x + self.attn(self.ln1(x)) # layer norm + attention + residual
x = x + self.mlp(self.ln2(x)) # layer norm + MLP + residual
return x
The second of these, the MLP module, can be implemented as a two-layer fully connected feed-forward network:
class MyMLP(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.fc1 = nn.Linear(conf.dim_embd, conf.dim_embd * 4)
# the 4x hidden dimension follows the usual Transformer/GPT-2 convention; the exact factor is a free choice
self.gelu = nn.GELU()
self.fc2 = nn.Linear(conf.dim_embd * 4, conf.dim_embd)
self.dropout = nn.Dropout(conf.dropout)
def forward(self, x: FloatTensor) -> FloatTensor:
x = self.fc1(x)
x = self.gelu(x)
x = self.dropout(x)
x = self.fc2(x)
return x
As for the attention mechanism, the heart of the model, let us also start with a rough design:
class MyMultiHeadAttention(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
def forward(self, x: FloatTensor) -> FloatTensor:
pass
So what does this "attention mechanism" actually mean? In everyday usage, paying attention means concentrating most of our resources (time, mental effort) on the small fraction of information that is important or relevant, and spending only a little on the large remainder that is not. Using document retrieval as an analogy: given a query, we find the documents K that are highly relevant to this query, and then concentrate our resources on processing the contents V of those documents.
Mathematically, we can formalize the attention mechanism described above as the following function:
\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V\]
Here \(Q\) is a batch of queries, \(K\) is a batch of documents, and \(V\) holds the contents of those documents. \(QK^{\top}/\sqrt{d_k}\) can be read as computing the similarity between each document and the given query; the \(\mathrm{softmax}\) squashes these similarities into the range 0-1, giving how much resource should be allocated to each document; multiplying by the matrix \(V\) then processes the document contents according to this attention. When a GPT processes a sequence, every token serves at once as a query, as a document k, and as document content v. When we process a given token, all the tokens in the sequence (including that token itself) are content we need to process, which is where the name self-attention comes from. Translated into code:
class MyMultiHeadAttention(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
def attention(self, q, k, v: FloatTensor) -> FloatTensor:
d_k = k.shape[-1]
return F.softmax(q @ k.mT / math.sqrt(d_k), dim=-1) @ v
def forward(self, x: FloatTensor) -> FloatTensor:
q, k, v = self.make_QKV(x)  # make_QKV, which projects x into Q, K and V, is implemented further below
return self.attention(q, k, v)
When modelling/generating in an auto-regressive fashion, we must make sure that every token can only attend to the tokens before it (to its left), never to the tokens after it (to its right). In code this can be done with a mask along the lines of `mask = [True, True, ..., False, False]`: the attention scores of the not-yet-modelled/generated positions on the right are set to negative infinity before the softmax, so that their attention weights become 0 afterwards. In practice we can also sprinkle in some dropout here. The upgraded attention function then becomes:
class MyMultiHeadAttention(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
# causal mask to ensure that attention is only applied to the left in the input sequence
bias = torch.tril(torch.ones(conf.ctx_win, conf.ctx_win)).view(1, 1, conf.ctx_win, conf.ctx_win)
self.register_buffer("bias", bias)
self.attn_dropout = nn.Dropout(conf.dropout)
self.resid_dropout = nn.Dropout(conf.dropout)
def attention(self, q, k, v, mask) -> FloatTensor:
scores = q @ k.mT / math.sqrt(k.shape[-1])
# ensure that attention is only applied to the left
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
attn = self.attn_dropout(attn)
attn = attn @ v
return attn
def forward(self, x: FloatTensor) -> FloatTensor:
q, k, v = self.make_QKV(x)
seq_len = x.shape[1]
mask = self.bias[:, :, :seq_len, :seq_len]
y = self.attention(q, k, v, mask)
y = self.resid_dropout(y)
return y
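To see what this causal mask actually does, here is a small standalone demo (a sequence length of 4 is an arbitrary choice for illustration) of the lower-triangular mask and of how masking before the softmax removes attention to future positions:
import torch

seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))  # 1 = may attend, 0 = future position
print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
scores = torch.randn(seq_len, seq_len)                 # fake attention scores
scores = scores.masked_fill(mask == 0, float('-inf'))  # hide future positions
weights = torch.softmax(scores, dim=-1)
print(weights)  # each row sums to 1, with zeros strictly above the diagonal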
So where do the matrices \(Q, K, V\) used in self-attention come from? In GPT they are obtained by linearly projecting the input vectors \(x\) into three different spaces: we learn three matrices \(W_Q, W_K, W_V\) and multiply the input \(x\) by each of them. For convenience of implementation (and following common practice), we let \(W_Q, W_K, W_V\) all map from dimension `dim_embd` to `dim_embd`.
class MyMultiHeadAttention(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
self.W_q = nn.Linear(conf.dim_embd, conf.dim_embd)
self.W_k = nn.Linear(conf.dim_embd, conf.dim_embd)
self.W_v = nn.Linear(conf.dim_embd, conf.dim_embd)
# causal mask to ensure that attention is only applied to the left in the input sequence
bias = torch.tril(torch.ones(conf.ctx_win, conf.ctx_win)).view(1, 1, conf.ctx_win, conf.ctx_win)
self.register_buffer("bias", bias)
self.attn_dropout = nn.Dropout(conf.dropout)
self.resid_dropout = nn.Dropout(conf.dropout)
def attention(self, q, k, v, mask) -> FloatTensor:
scores = q @ k.mT / math.sqrt(k.shape[-1])
# ensure that attention is only applied to the left
scores = scores.masked_fill(mask==0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
attn = self.attn_dropout(attn)
attn = attn @ v
return attn
def make_QKV(self, x):
q = self.W_q(x)
k = self.W_k(x)
v = self.W_v(x)
return q, k, v
def forward(self, x: FloatTensor) -> FloatTensor:
q, k, v = self.make_QKV(x)
seq_len = x.shape[1]
mask = self.bias[:, :, :seq_len, :seq_len]
y = self.attention(q, k, v, mask)
y = self.resid_dropout(y)
return y
In the Transformer we use not just one attention head but several; intuitively this corresponds to attending from multiple angles, or to different aspects of the sequence. The outputs of the individual heads are concatenated and passed through one more linear transformation (a multiplication by a matrix \(W^O\)):
\[\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O\]
where \(\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)\).
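In the implementation below, the heads do not get separate weight matrices; instead the `dim_embd`-dimensional projections are split into `n_attn_head` chunks via `view` and `transpose`. A quick shape walkthrough (the batch size and sequence length are made up for illustration):
import torch

batch_size, seq_len, dim_embd, n_attn_head = 2, 5, 768, 12
x = torch.randn(batch_size, seq_len, dim_embd)  # stands in for W_q(x), W_k(x) or W_v(x)
# split the last dimension into heads: [batch, seq, dim] -> [batch, head, seq, head_dim]
q = x.view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
print(q.shape)  # torch.Size([2, 12, 5, 64])
# after attention is computed per head in parallel, the heads are merged back
y = q.transpose(1, 2).contiguous().view(batch_size, seq_len, dim_embd)
print(y.shape)  # torch.Size([2, 5, 768])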
Putting these ideas together, the final code for the multi-head attention mechanism is:
class MyMultiHeadAttention(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
self.W_q = nn.Linear(conf.dim_embd, conf.dim_embd)
self.W_k = nn.Linear(conf.dim_embd, conf.dim_embd)
self.W_v = nn.Linear(conf.dim_embd, conf.dim_embd)
self.W_out = nn.Linear(conf.dim_embd, conf.dim_embd)
# causal mask to ensure that attention is only applied to the left in the input sequence
bias = torch.tril(torch.ones(conf.ctx_win, conf.ctx_win)).view(1, 1, conf.ctx_win, conf.ctx_win)
self.register_buffer("bias", bias)
self.attn_dropout = nn.Dropout(conf.dropout)
self.resid_dropout = nn.Dropout(conf.dropout)
def attention(self, q, k, v, mask) -> FloatTensor:
scores = q @ k.mT / math.sqrt(k.shape[-1])
# ensure that attention is only applied to the left
scores = scores.masked_fill(mask==0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
attn = self.attn_dropout(attn)
attn = attn @ v
return attn
def make_QKV(self, x, batch_size, seq_len, n_attn_head, dim_embd):
q = self.W_q(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
k = self.W_k(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
v = self.W_v(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
return q, k, v
def forward(self, x: FloatTensor) -> FloatTensor:
batch_size, seq_len, dim_embd = x.shape
n_attn_head, dim_embd = self.conf.n_attn_head, self.conf.dim_embd
q, k, v = self.make_QKV(x, batch_size, seq_len, n_attn_head, dim_embd)
mask = self.bias[:, :, :seq_len, :seq_len]
y = self.attention(q, k, v, mask)
y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, dim_embd) # re-assemble all head outputs side by side
y = self.W_out(y) # output projection
y = self.resid_dropout(y)
return y
With that, we have written a GPT from scratch, step by step. The complete code is as follows:
'''
Hand-made GPT, adapted from [nanoGPT](https://github.com/karpathy/nanoGPT/)
'''
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, FloatTensor
@dataclass
class MyGPTConfig:
vocab_size: int = 50257
ctx_win: int = 1024
dim_embd: int = 768
n_attn_head: int = 12
n_tfm_block: int = 12
dropout: float = 0.0
use_bias: bool = True
class MyMLP(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.fc1 = nn.Linear(conf.dim_embd, conf.dim_embd * 4)
self.gelu = nn.GELU()
self.fc2 = nn.Linear(conf.dim_embd * 4, conf.dim_embd)
self.dropout = nn.Dropout(conf.dropout)
def forward(self, x: FloatTensor) -> FloatTensor:
x = self.fc1(x)
x = self.gelu(x)
x = self.dropout(x)
x = self.fc2(x)
return x
class MyMultiHeadAttention(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
self.W_q = nn.Linear(conf.dim_embd, conf.dim_embd)
self.W_k = nn.Linear(conf.dim_embd, conf.dim_embd)
self.W_v = nn.Linear(conf.dim_embd, conf.dim_embd)
self.W_out = nn.Linear(conf.dim_embd, conf.dim_embd)
# causal mask to ensure that attention is only applied to the left in the input sequence
bias = torch.tril(torch.ones(conf.ctx_win, conf.ctx_win)).view(1, 1, conf.ctx_win, conf.ctx_win)
self.register_buffer("bias", bias)
self.attn_dropout = nn.Dropout(conf.dropout)
self.resid_dropout = nn.Dropout(conf.dropout)
def attention(self, q, k, v, mask) -> FloatTensor:
scores = q @ k.mT / math.sqrt(k.shape[-1])
# ensure that attention is only applied to the left
scores = scores.masked_fill(mask==0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
attn = self.attn_dropout(attn)
attn = attn @ v
return attn
def make_QKV(self, x, batch_size, seq_len, n_attn_head, dim_embd):
q = self.W_q(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
k = self.W_k(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
v = self.W_v(x).view(batch_size, seq_len, n_attn_head, dim_embd // n_attn_head).transpose(1, 2)
return q, k, v
def forward(self, x: FloatTensor) -> FloatTensor:
batch_size, seq_len, dim_embd = x.shape
n_attn_head, dim_embd = self.conf.n_attn_head, self.conf.dim_embd
q, k, v = self.make_QKV(x, batch_size, seq_len, n_attn_head, dim_embd)
mask = self.bias[:, :, :seq_len, :seq_len]
y = self.attention(q, k, v, mask)
y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, dim_embd) # re-assemble all head outputs side by side
y = self.W_out(y) # output projection
y = self.resid_dropout(y)
return y
class MyTransformerBlock(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.ln1 = nn.LayerNorm(conf.dim_embd)
self.attn = MyMultiHeadAttention(conf)
self.ln2 = nn.LayerNorm(conf.dim_embd)
self.mlp = MyMLP(conf)
def forward(self, x: FloatTensor) -> FloatTensor:
'''[batch_size, seq_len, dim_embd] -> [batch_size, seq_len, dim_embd]'''
x = x + self.attn(self.ln1(x)) # layer norm + attention + residual
x = x + self.mlp(self.ln2(x)) # layer norm + MLP + residual
return x
class MyGPT(nn.Module):
def __init__(self, conf: MyGPTConfig) -> None:
super().__init__()
self.conf = conf
self.tok_embd = nn.Embedding(conf.vocab_size, conf.dim_embd)
self.pos_embd = nn.Embedding(conf.ctx_win, conf.dim_embd)
self.tfm_blocks = nn.ModuleList([MyTransformerBlock(conf) for _ in range(conf.n_tfm_block)])
self.lm_head = nn.Linear(conf.dim_embd, conf.vocab_size)
self.dropout = nn.Dropout(conf.dropout)
self.layer_norm = nn.LayerNorm(conf.dim_embd)
def forward(self, x_id: LongTensor) -> FloatTensor:
'''(padded) sequence of token ids -> logits of shape [batch_size, 1, vocab_size]'''
pos = torch.arange(x_id.shape[1], device=x_id.device)
tok_embd = self.tok_embd(x_id)
pos_embd = self.pos_embd(pos)
x_embd = tok_embd + pos_embd
x_embd = self.dropout(x_embd)
for tfm_block in self.tfm_blocks:
x_embd = tfm_block(x_embd)
x_embd = self.layer_norm(x_embd)
# note: using list [-1] to preserve the time dimension
logits = self.lm_head(x_embd[:, [-1], :])
return logits
def generate(self, x_id: LongTensor, max_new_tokens: int, temperature=1.0, top_k:int=1) -> LongTensor:
'''(padded) sequence of token ids -> sequence of token ids'''
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at ctx_win
x_id_cond = x_id if x_id.size(1) <= self.conf.ctx_win else x_id[:, -self.conf.ctx_win:]
logits = self.forward(x_id_cond)
logits = logits[:, -1, :] / temperature
if top_k > 1:
v, _ = torch.topk(logits, min(top_k, logits.shape[-1])) # top_k cannot exceed vocab_size
logits[logits < v[:, [-1]]] = -float('Inf') # logits smaller than the k-th largest are set to -Inf, so the probabilities of those tokens become 0
probs = F.softmax(logits, dim=-1)
new_tok_id = torch.multinomial(probs, num_samples=1)
x_id = torch.cat([x_id, new_tok_id], dim=1)
return x_id
if __name__ == '__main__':
conf = MyGPTConfig()
model = MyGPT(conf)
print(model)
inp = torch.LongTensor([
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
])
print(model(inp))
print(model.generate(inp, max_new_tokens=3))
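As a final sanity check, we can count the parameters of the model we just instantiated; since this version, unlike nanoGPT, does not tie the token-embedding and lm_head weights, the total comes out somewhat larger than the roughly 124M parameters of GPT-2 small:
n_params = sum(p.numel() for p in model.parameters())
print(f"number of parameters: {n_params / 1e6:.1f}M")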