LLaMA 3 Source Code Walkthrough - Large Language Models 5

Posted by vanilla阿草 on 2024-05-07

I wasn't really planning to write this post, because articles like it are everywhere online and mine may well not be as good as the others. But on reflection, I built this blog to improve myself by teaching outward, not because every post has to be brilliant enough to attract readers. So, to round out this series, this post walks through the model architecture of the currently hottest model, LLaMA 3, and works out what today's LLMs actually look like.

To be clear up front, LLaMA 3 makes no changes to the network architecture relative to LLaMA 2. As a Zhihu user put it, "the release of llama3 underscores the importance of data engineering: with the architecture unchanged, more data and higher data quality alone bring a clear improvement in model quality." Still, for a beginner like me, reading carefully through the source of an LLM is very much worth doing.

https://zhuanlan.zhihu.com/p/693428105

One more note: this post reads the source at commit d6e09315954d1a547bf45e37269978c049e73d33. If Meta later updates the code so that it no longer matches what is written here, check out that revision first. If something still doesn't add up, leave a comment under this post and we can work it out together.

The Llama class: getting started

Llama.build, and how to read the source

Starting from the llama3 README, we find this demo, which runs a chat via

from llama import Dialog, Llama

generator = Llama.build(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)
results = generator.chat_completion(dialogs, max_gen_len, temperature, top_p)

to complete the dialog. It first calls Llama.build, then calls generator.chat_completion on the returned object to run the chat; the imported package is llama. That points us to the llama folder at the top of the repo, so we start by looking at its __init__.py:

from .generation import Llama
from .model import ModelArgs, Transformer
from .tokenizer import Dialog, Tokenizer

So the Llama.build that the demo calls lives in generation.py. Following the trail, we find:

class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
        seed: int = 1,
    ) -> "Llama":
        """
        Build a Llama instance by initializing and loading a model checkpoint.

        Args:
            ckpt_dir (str): Path to the directory containing checkpoint files.
            tokenizer_path (str): Path to the tokenizer file.
            max_seq_len (int): Maximum sequence length for input text.
            max_batch_size (int): Maximum batch size for inference.
            model_parallel_size (Optional[int], optional): Number of model parallel processes.
                If not provided, it's determined from the environment. Defaults to None.

        Returns:
            Llama: An instance of the Llama class with the loaded model and tokenizer.
        """
        # First, some model-parallel setup
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        # When one model is trained/served across multiple devices, each process gets a rank; this configures it.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        # Random seed
        torch.manual_seed(seed)
        # Silence stdout on all ranks except rank 0 so output is printed only once
        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        # Finally, the model-loading code
        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        # Check that the number of checkpoint shards matches expectations
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"

        # Load this rank's shard. On multiple devices get_model_parallel_rank() returns a different index per process, so no for-loop is needed; the style is reminiscent of CUDA programming.
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")

        # TODO: read params.json and load it into model_args via the ModelArgs class, covered below
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())
        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )

        # TODO: load the Tokenizer, also covered below
        tokenizer = Tokenizer(model_path=tokenizer_path)
        assert model_args.vocab_size == tokenizer.n_words

        # Half-precision defaults (bfloat16 if supported, otherwise float16)
        if torch.cuda.is_bf16_supported():
            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        
        # TODO: yes, the body of llama3 is this Transformer class; a plain model.load_state_dict loads the weights. Covered below as well.
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        # TODO: at this point everything is loaded; wrap it in a Llama instance and return it
        return Llama(model, tokenizer)

The logic here is quite clear; it just leaves us a few TODOs, all of which we will get to.

ModelArgs

Let's look at the ModelArgs class first. It does nothing but hold a handful of hyperparameters, as the @dataclass decorator already gives away:

@dataclass
class ModelArgs:
    dim: int = 4096  # model (embedding) dimension
    n_layers: int = 32  # number of transformer layers
    n_heads: int = 32  # number of attention heads
    n_kv_heads: Optional[int] = None  # number of key/value heads for grouped-query attention (defaults to n_heads when None)
    vocab_size: int = -1  # vocabulary size (set from params.json)
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000
    max_batch_size: int = 32
    max_seq_len: int = 2048  # maximum sequence length

Llama.__init__()

The final line, return Llama(model, tokenizer), ends up calling Llama.__init__(), which looks like this:

from llama.tokenizer import ChatFormat, Dialog, Message, Tokenizer

def __init__(self, model: Transformer, tokenizer: Tokenizer):
    self.model = model
    self.tokenizer = tokenizer
    # TODO: the ChatFormat class is analyzed later
    self.formatter = ChatFormat(tokenizer)

Yes, plain attribute assignment and it's done. The ChatFormat class used for formatter is covered together with the tokenizer below.

The Transformer class: the LLaMA 3 architecture in detail

This is probably the part people care about most.

Transformer.__init__()

First, model initialization, which simply sets up a pile of attributes. Straight to the code; the commentary is in the comments:

from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
)  # These FairScale modules exist only for model parallelism; no need to dig into them

class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        # VocabParallelEmbedding comes from fairscale and behaves like torch.nn.Embedding
        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            # TODO: TransformerBlock
            self.layers.append(TransformerBlock(layer_id, params))

        # TODO: RMSNorm
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)

        # ColumnParallelLinear is the parallel equivalent of torch.nn.Linear
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        # TODO: precompute_freqs_cis
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )

RMSNorm

RMSNorm is LayerNorm with the mean-centering removed (it effectively assumes the activations have zero mean):

\[\begin{equation} \bar{a}_i=\frac{a_i}{\mathrm{RMS}(a)} g_i \text{ where } \mathrm{RMS}(a)=\sqrt{\frac{1}{n} \sum_{i=1}^n a_i^2} \end{equation} \]

Note: LayerNorm, for comparison, is

\[\begin{equation} \bar{a}_i=\frac{a_i - \mu }{ \sigma } g_i \text{ where } \mu=\frac{1}{n} \sum_{i=1}^n {a_i } \text{ and } \sigma=\sqrt{\frac{1}{n} \sum_{i=1}^n {(a_i - \mu)}^2} \end{equation}\]

Implemented in code, it looks like this:

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain g, initialized to ones

    def _norm(self, x):
        # torch.rsqrt: reciprocal of the square root, used here as 1 / RMS(x)
        # x.pow(2).mean(-1, keepdim=True): mean of the squared elements along the last dim
        # i.e. x divided by sqrt(mean of the squared elements + eps)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

The authors report that, besides being simpler than LayerNorm, this cuts computation time by roughly 7% to 64% across a range of models.
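
As a quick sanity check of the module (mine, not from the original post), its output matches computing the formula above by hand:

import torch

torch.manual_seed(0)
x = torch.randn(2, 5, 16)          # (batch, seq, dim)
norm = RMSNorm(dim=16, eps=1e-6)   # the gain g starts as all ones

# Manual computation of the formula: a_i / RMS(a) * g_i
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
manual = x / rms * norm.weight

print(torch.allclose(norm(x), manual, atol=1e-6))  # True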

Rotary Position Embedding (RoPE)

This section draws on the blog of Su Jianlin (蘇劍林), the inventor of RoPE.

RoPE realizes relative position encoding by means of absolute position encoding. Suppose we add absolute position information to \(q,k\) through the following operations:

design operations \(\boldsymbol{f}(\cdot, m),\boldsymbol{f}(\cdot, n)\) for \(q\) and \(k\) respectively, such that after applying them, \(\tilde{\boldsymbol{q}}_m,\tilde{\boldsymbol{k}}_n\) carry the absolute positions \(m,n\). Since the core operation of attention is the inner product, we want the inner product to carry relative position information, so we posit the identity:

\[\begin{equation}\langle\boldsymbol{f}(\boldsymbol{q}, m), \boldsymbol{f}(\boldsymbol{k}, n)\rangle = g(\boldsymbol{q},\boldsymbol{k},m-n)\end{equation} \]

Solving it yields:

\[\begin{equation} \boldsymbol{f}(\boldsymbol{q}, m) = R_f (\boldsymbol{q}, m)e^{\text{i}\Theta_f(\boldsymbol{q}, m)} = \Vert q\Vert e^{\text{i}(\Theta(\boldsymbol{q}) + m\theta)} = \boldsymbol{q} e^{\text{i}m\theta}\end{equation} \]

which, in matrix form, is:

\[\begin{equation} \boldsymbol{f}(\boldsymbol{q}, m) =\begin{pmatrix}\cos m\theta & -\sin m\theta\\ \sin m\theta & \cos m\theta\end{pmatrix} \begin{pmatrix}q_0 \\ q_1\end{pmatrix}\end{equation} \]

Because the inner product is linear, any even-dimensional RoPE can be expressed as a concatenation of the two-dimensional case:

\[\begin{equation}\scriptsize{\underbrace{\begin{pmatrix} \cos m\theta_0 & -\sin m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\ \sin m\theta_0 & \cos m\theta_0 & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m\theta_1 & -\sin m\theta_1 & \cdots & 0 & 0 \\ 0 & 0 & \sin m\theta_1 & \cos m\theta_1 & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2-1} & -\sin m\theta_{d/2-1} \\ 0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2-1} & \cos m\theta_{d/2-1} \\ \end{pmatrix}}_{\boldsymbol{\mathcal{R}}_m} \begin{pmatrix}q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1}\end{pmatrix}}\end{equation} \]

which lets us implement RoPE efficiently as an element-wise product:

\[\begin{equation}\begin{pmatrix}q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix}\otimes\begin{pmatrix}\cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1} \end{pmatrix} + \begin{pmatrix}-q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{pmatrix}\otimes\begin{pmatrix}\sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1} \end{pmatrix}\end{equation} \]

precompute_freqs_cis

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Pair up the embedding dimensions and compute the rotation frequency for each pair
    # torch.arange(0, dim, 2): generates [0, 2, 4, ..., dim-2] (e.g. [0, 2, ..., 126] for dim=128)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)  # t = [0, 1, ..., end-1], the positions
    # torch.outer(a, b): outer product, result[i, j] = a[i] * b[j]
    freqs = torch.outer(t, freqs)  # freqs.shape = (end, dim//2); entry [m, j] is m * theta_j

    # Turn the angles into unit complex numbers
    # torch.polar(abs, angle): abs*cos(angle) + abs*sin(angle)*j
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # freqs_cis.shape = (end, dim//2)
    return freqs_cis
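
To get a feel for what this returns, here is a small inspection (a sketch assuming head_dim = 128, as in the 8B model):

import torch

freqs_cis = precompute_freqs_cis(dim=128, end=8, theta=500000.0)
print(freqs_cis.shape, freqs_cis.dtype)   # torch.Size([8, 64]) torch.complex64

# Row m, column j holds exp(i * m * theta_j), so every entry is a pure rotation
# (magnitude 1), and row 0 is all 1 + 0j.
print(torch.allclose(freqs_cis.abs(), torch.ones_like(freqs_cis.abs())))  # True
print(freqs_cis[0][:3])  # tensor([1.+0.j, 1.+0.j, 1.+0.j])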

reshape_for_broadcast

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # ndim is the number of dimensions of x; here it should be 4
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Keep the seq-len and head-dim axes, set every other axis to 1 so it broadcasts
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    # i.e. (1, x.shape[1], 1, x.shape[-1])
    return freqs_cis.view(*shape)

apply_rotary_emb

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """將xq和xk的最後一個維度進行復數運算,得到新的xq和xk"""
    # xq.shape = [bsz, seqlen, self.n_local_heads, self.head_dim]
    # xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2 , 2]
    # torch.view_as_complex用於將二維向量轉換為複數域 torch.view_as_complex即([x,y]) -> (x+yj)
    # 所以經過view_as_complex變換後xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # freqs_cis.shape = (1,x.shape[1],1,x.shape[-1])
    
    # xq_ 與freqs_cis廣播哈達瑪積
    # [bsz, seqlen, self.n_local_heads, self.head_dim//2] * [1,seqlen,1,self.head_dim//2]
    # torch.view_as_real用於將複數再轉換回實數向量, 再經過flatten展平第4個維度 
    # [bsz, seqlen, self.n_local_heads, self.head_dim//2] ->[bsz, seqlen, self.n_local_heads, self.head_dim//2,2 ] ->[bsz, seqlen, self.n_local_heads, self.head_dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
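
Putting the three helpers together, we can check numerically that RoPE really encodes only relative position: the same q and k vectors placed at positions (3, 5) or at positions (7, 9) give identical inner products, since only m - n matters. This is my own sanity check, not code from the repo:

import torch

head_dim, seqlen = 64, 16
freqs_cis = precompute_freqs_cis(head_dim, seqlen)

q = torch.randn(1, seqlen, 1, head_dim)
k = torch.randn(1, seqlen, 1, head_dim)

q1, k1 = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
# Attention score between positions m=3 and n=5
s_a = (q1[0, 3, 0] * k1[0, 5, 0]).sum()

# Place the very same vectors at m=7 and n=9 (same offset m - n = -2)
q_shift, k_shift = torch.zeros_like(q), torch.zeros_like(k)
q_shift[0, 7, 0] = q[0, 3, 0]
k_shift[0, 9, 0] = k[0, 5, 0]
q2, k2 = apply_rotary_emb(q_shift, k_shift, freqs_cis=freqs_cis)
s_b = (q2[0, 7, 0] * k2[0, 9, 0]).sum()

print(torch.allclose(s_a, s_b, atol=1e-5))  # True: only the relative offset matters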

TransformerBlock

This class is straightforward; it is just a single transformer block.


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        """初始化函式主要就是定義了transformer block的各個元件,包括自注意力機制和前饋神經網路。"""
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        # TODO: Attention
        self.attention = Attention(args)

        # TODO: FeedForward
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of,  ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """這個函式是transformer block的前向傳播函式,輸入是x,start_pos,freqs_cis,mask,輸出是out"""
        # 這個函式的實現比較簡單,首先對輸入張量x進行自注意力機制計算,然後對計算結果進行殘差連線和歸一化,再透過前饋神經網路計算,最後再次進行殘差連線和歸一化。
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

Attention

To implement Grouped Query Attention, the source defines a helper, repeat_kv, which repeats the key and value tensors n_rep times along the head dimension so they match the number of query heads. It does this with expand, replicating the tensor along a new fourth dimension, followed by a reshape into the final shape.

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
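
As the docstring hints, repeat_kv is just a hand-rolled torch.repeat_interleave along the kv-head dimension; a quick equivalence check (my own, for illustration):

import torch

x = torch.randn(2, 10, 8, 64)   # (bs, seqlen, n_kv_heads, head_dim)
n_rep = 4                        # e.g. 32 query heads / 8 kv heads

out = repeat_kv(x, n_rep)
print(out.shape)  # torch.Size([2, 10, 32, 64])

# Same result as repeat_interleave over the kv-head dimension
print(torch.equal(out, torch.repeat_interleave(x, repeats=n_rep, dim=2)))  # True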


# Abridged version of Attention
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.wq = Linear(...)
        self.wk = Linear(...)
        self.wv = Linear(...)
        
        self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

    def forward(self, x: torch.Tensor):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        # Apply rotary position embedding before the attention op
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        # ... (KV cache, repeat_kv and the head transposes are omitted here; see the sketch below)
        # Then the usual scaled dot-product attention
        scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
        scores = F.softmax(scores.float(), dim=-1)
        output = torch.matmul(scores, xv)  # (batch_size, seq_len, dim)
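
The "..." above hides the KV cache and the grouped-query head expansion. Below is a self-contained sketch of that missing middle, with my own toy shapes rather than the repo's exact code: new keys/values are written into a preallocated cache at positions [start_pos, start_pos + seqlen), the whole prefix is read back and expanded with repeat_kv, and then ordinary scaled dot-product attention runs with the head dimension moved in front of the sequence dimension:

import math
import torch
import torch.nn.functional as F

# Minimal GQA + KV cache sketch (illustration only)
bsz, seqlen, start_pos = 1, 4, 0
n_heads, n_kv_heads, head_dim = 8, 2, 16
n_rep = n_heads // n_kv_heads
max_seq_len = 32

cache_k = torch.zeros(bsz, max_seq_len, n_kv_heads, head_dim)
cache_v = torch.zeros(bsz, max_seq_len, n_kv_heads, head_dim)

xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
xv = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

# Write the new positions into the cache, then read back the whole prefix
cache_k[:bsz, start_pos:start_pos + seqlen] = xk
cache_v[:bsz, start_pos:start_pos + seqlen] = xv
keys = repeat_kv(cache_k[:bsz, :start_pos + seqlen], n_rep)    # (bsz, len, n_heads, head_dim)
values = repeat_kv(cache_v[:bsz, :start_pos + seqlen], n_rep)

# Heads in front of the sequence dim, then standard scaled dot-product attention
q, k, v = (t.transpose(1, 2) for t in (xq, keys, values))
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
output = torch.matmul(F.softmax(scores, dim=-1), v)            # (bsz, n_heads, seqlen, head_dim)
print(output.transpose(1, 2).reshape(bsz, seqlen, -1).shape)   # torch.Size([1, 4, 128])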

The FeedForward class and the SwiGLU activation

The FeedForward class implements:

\[\begin{equation} \mathrm{FFN}_{\mathrm{SwiGLU}}(x, W, V, W_2)=(\mathrm{Swish}_1(xW) \otimes xV)\,W_2 \end{equation}\]


The activation used is SwiGLU, where:

\[\begin{equation}\mathrm{SwiGLU}(x)=\mathrm{Swish}(Wx + b) \otimes (Vx + c)\end{equation} \]

\[\begin{equation}\mathrm{Swish}(x) = x \cdot \mathrm{sigmoid}(\beta x)\end{equation} \]

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):  # skipped here: it mostly just works out hidden_dim and builds the linear layers w1, w2, w3 (see the sketch below)
        ...

    def forward(self, x):
        # self.w2( silu(self.w1(x)) * self.w3(x) ), i.e. the SwiGLU FFN
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
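
The skipped __init__ mostly computes the actual hidden dimension before building the three parallel linear layers. From memory, the sizing logic goes roughly as follows; treat ffn_hidden_dim as my own helper name and check the repo for the authoritative code:

from typing import Optional

def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]) -> int:
    # Start from the Transformer-standard 4*dim, shrink to 2/3 of it (so the
    # three-matrix SwiGLU FFN keeps roughly the parameter count of a two-matrix FFN),
    # optionally rescale, then round up to a multiple of multiple_of.
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim

# With 8B-style settings (dim=4096, multiple_of=1024, ffn_dim_multiplier=1.3) this gives 14336
print(ffn_hidden_dim(4096, 1024, 1.3))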

The following is adapted from a Zhihu post.

When \(\beta = 1\), \(\mathrm{Swish}(x)\) is exactly \(\mathrm{SiLU}(x)\):

\[\begin{equation}\mathrm{SiLU}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1+e^{-x}}\end{equation} \]

Its curve looks like a smoothed ReLU that dips slightly below zero for negative inputs.

Transformer.forward()

The forward pass is the familiar Transformer forward pass.

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape  # batch size and sequence length
        h = self.tok_embeddings(tokens)  # embed the input tokens into the tensor h
        self.freqs_cis = self.freqs_cis.to(h.device)  # move the precomputed frequencies to the input's device
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]  # slice out the frequencies for these positions

        mask = None  # causal mask used in self-attention to hide future positions
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)  # (seqlen, seqlen) tensor filled with -inf
            mask = torch.triu(mask, diagonal=1)  # keep only the strict upper triangle
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)  # prepend start_pos columns of zeros so cached positions stay fully visible

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)  # run the transformer blocks one by one
        h = self.norm(h)  # final RMSNorm
        output = self.output(h).float()  # linear output head over the vocabulary
        return output
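
The hstack with a (seqlen, start_pos) block of zeros is what makes incremental decoding with the KV cache work: the start_pos cached positions stay fully visible, and only the freshly appended tokens are causally masked among themselves. A tiny illustration (mine) with start_pos=2 and seqlen=3:

import torch

seqlen, start_pos = 3, 2
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0.,   0., -inf],
#         [0., 0., 0.,   0.,   0.]])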

Tokenizer

The Tokenizer class mostly just calls into the tiktoken library, so there isn't much to explain. Most of the functions here define a pile of things up front, but once you dig into the actual logic you find it still comes down to library calls.

class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]
    num_reserved_special_tokens = 256
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(model_path), model_path
        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>", "<|end_of_text|>",
            "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>",  # end of turn
            "<|reserved_special_token_0|>", "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>", "<|reserved_special_token_3|>", "<|reserved_special_token_4|>",
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name, pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens,
        )
        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }

    def encode(
        self, s: str, *, bos: bool, eos: bool,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_tokens ("all"|set[str]): allowed special tokens in string
            disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str

        # The tiktoken tokenizer can handle <=400k chars without pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # Here we iterate over subsequences and split if we exceed the limit of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                # the actual tiktoken call happens here
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
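
Basic usage is a simple encode/decode round trip; the model path below is a placeholder for wherever your downloaded tokenizer.model lives:

tok = Tokenizer(model_path="/path/to/Meta-Llama-3-8B/tokenizer.model")  # placeholder path

ids = tok.encode("Hello, Llama 3!", bos=True, eos=False)
print(ids[0] == tok.bos_id)   # True: BOS was prepended
print(tok.decode(ids[1:]))    # 'Hello, Llama 3!'
print(tok.n_words)            # vocabulary size; should print 128256 for Llama 3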

ChatFormat

The ChatFormat class is a further wrapper around Tokenizer, providing three encoding methods: encode_header, encode_message, and encode_dialog_prompt.

class ChatFormat:
    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message: Message) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(self, message: Message) -> List[int]:
        tokens = self.encode_header(message)
        tokens.extend(
            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        )
        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        return tokens

    def encode_dialog_prompt(self, dialog: Dialog) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        for message in dialog:
            tokens.extend(self.encode_message(message))
        # Add the start of an assistant message for the model to complete.
        tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
        return tokens
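
Decoded back into text, the prompt produced by encode_dialog_prompt is the familiar Llama 3 chat template. Roughly, with a hypothetical one-message dialog and the tok instance from the previous example:

dialog = [{"role": "user", "content": "What is RoPE?"}]
tokens = ChatFormat(tok).encode_dialog_prompt(dialog)

# Decoding those tokens yields (approximately) the string:
# "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWhat is RoPE?<|eot_id|>"
# "<|start_header_id|>assistant<|end_header_id|>\n\n"
# i.e. the prompt ends with an open assistant header for the model to complete.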

Summary

That concludes the source walkthrough. If anything is unclear, leave a comment.
