將 LLMs 精調至 1.58 位元: 使極端量化變簡單

HuggingFace發表於2024-09-29

隨著大語言模型 (LLMs) 規模和複雜性的增長,尋找減少它們的計算和能耗的方法已成為一個關鍵挑戰。一種流行的解決方案是量化,其中引數的精度從標準的 16 位浮點 (FP16) 或 32 位浮點 (FP32) 降低到 8 位或 4 位等低位格式。雖然這種方法顯著減少了記憶體使用量並加快了計算速度,但往往以準確性為代價。過度降低精度可能導致模型丟失關鍵資訊,從而導致效能下降。

BitNet 是一種特殊的 transformers 架構,它用僅三個值: (-1, 0, 1) 表示每個引數,提供了每個引數僅為 1.58 $ (log_2(3)) $ 位元的極端量化。然而,這需要從頭開始訓練一個模型。雖然結果令人印象深刻,但並非每個人都有預算來進行大語言模型的預訓練。為了克服這一限制,我們探索了一些技巧,允許將現有模型精調至 1.58 位元!繼續閱讀以瞭解更多!

目錄

  • 簡介
  • 更深入地瞭解什麼是 BitNet
  • 1.58 位元的預訓練結果
  • 1.58 位元的微調
  • 使用的核心和測試標準
  • 結論
  • 致謝
  • 更多資源

簡介

BitNet 是由微軟研究院提出的一種模型架構,其採用極端量化的方式,用僅三個值 -1、0 和 1 來表示每個引數。這導致模型每個引數僅使用 1.58 位元,顯著降低了計算和記憶體需求。

該架構在執行矩陣乘法時使用 INT8 加法計算,這與以 Llama 為例的傳統 LLM 架構的 FP16 乘加操作完全不同。

BitNet b1.58 的新計算正規化 (出處: BitNet 論文 https://arxiv.org/abs/2402.17764)

BitNet b1.58 的新計算正規化 (出處: BitNet 論文 https://arxiv.org/abs/2402.17764)

這種方法在理論上降低能耗,與 Llama 基準相比,BitNet b1.58 在矩陣乘法方面節省了 71.4 倍的計算能耗。

BitNet b1.58 與 Llama 的能耗對比 (出處: BitNet 論文 https://arxiv.org/abs/2402.17764)

BitNet b1.58 與 Llama 的能耗對比 (出處: BitNet 論文 https://arxiv.org/abs/2402.17764)

我們成功地使用 BitNet 架構對 Llama3 8B model 模型進行了精調,在下游任務中取得了良好的效能。我們開發的 8B 模型由 HF1BitLLM 組織釋出。其中兩個模型在 10B 的 token 上進行了不同的訓練設定的微調,而第三個模型在 100B 的 token 上進行了微調。值得注意的是,我們的模型在 MMLU 基準測試中超越了 Llama 1 7B 模型。

如何在 Transformers 中使用

為了將 BitNet 架構整合到 Transformers 中,我們引入了一種名為 “bitnet” 的新量化方法 (PR)。該方法涉及將標準的 Linear 層替換為專門設計用於 BitNet 架構的 BitLinear 層,其實現了相應的動態的啟用量化、權重解包和矩陣乘法的操作。

在 Transformers 中載入和測試模型非常簡單,API 沒有任何更改:

model = AutoModelForCausalLM.from_pretrained(
    "HF1BitLLM/Llama3-8B-1.58-100B-tokens",
    device_map="cuda",
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

input_text = "Daniel went back to the the the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:"

input_ids = tokenizer.encode(input_text, return_tensors="pt").cuda()
output = model.generate(input_ids, max_new_tokens=10)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)

透過這段程式碼,一切都直接在幕後完美地完成了,因此無需擔心額外的複雜性,您只需要做的只是安裝最新版本的 transformers。

要快速測試模型,請檢視這個 notebook

更深入地瞭解什麼是 BitNet

BitNet 在多頭注意力和前饋網路中替換了傳統的 Linear 層,使用了稱為 BitLinear 的特殊層,這些層使用三值精度 (甚至在初始版本中使用二值精度)。在這個專案中,我們使用的 BitLinear 層對權重使用三值精度 (取值為 -1、0 和 1),並將啟用量化為 8 位精度。我們在訓練和推理中使用不同的 BitLinear 實現,接下來的部分將會介紹。

在三值精度訓練中的主要障礙是權重值被離散化 (透過 round() 函式),因此不可微分。BitLinear 透過一個巧妙的技巧解決了這個問題: STE (Straight Through Estimator)。STE 允許梯度透過不可微分的取整操作,透過將其梯度近似為 1 (將 round() 視為等同於恆等函式) 來實現。另一種觀點是,STE 讓梯度透過取整步驟,好像取整從未發生過一樣,從而使用標準基於梯度的最佳化技術來更新權重。

使用 BitLienar 的 BitNet 模型架構 (出處: BitNet 論文 https://arxiv.org/pdf/2310.11453)

使用 BitLienar 的 BitNet 模型架構 (出處: BitNet 論文 https://arxiv.org/pdf/2310.11453)

訓練

我們在完整精度下進行訓練,但在訓練過程中將權重量化為三值,使用 per-tensor 的對稱量化。首先,我們計算權重矩陣的絕對值的平均值,並將其用作 scale。然後,我們將權重除以 scale,對值進行取整,將其限制在 -1 和 1 的區間內,最後將權重其反量化回完整精度。

\[scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|} \]

\[W_q = \text{clamp}_{[-1,1]}(\text{round}(W*scale)) \]

\[W_{dequantized} = W_q*scale_w \]

啟用然後被量化為指定的位元寬度 (在我們的情況下是 8 位),使用 per-token 的最大絕對值量化 (要了解量化方法的全面介紹,請檢視這篇 post)。這涉及將啟用縮放到 [-128, 127] 的範圍以適應 8 位位元寬度。量化公式如下:

\[scale_x = \frac{127}{|X|_{\text{max}, , \text{dim}=-1}} \]

\[X_q = \text{clamp}_{[-128,127]}(\text{round}(X*scale)) \]

\[X_{dequantized} = X_q * scale_x \]

為了使這些公式更加清晰,下面是一些使用 3x3 的矩陣的權重和啟用量化的例子:


例子 1: 權重矩陣量化

假設權重矩陣 $ W $為:

\[W = \begin{bmatrix} 0.8 & -0.5 & 1.2 \\ -1.5 & 0.4 & -0.9 \\ 1.3 & -0.7 & 0.2 \end{bmatrix} \]

第一步: 計算權重的 scale

使用公式:
k

\[scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|} \]

我們計算 $ W $ 啟用值的平均值:

\[\frac{1}{nm} \sum_{ij} |W_{ij}| = \frac{1}{9}(0.8 + 0.5 + 1.2 + 1.5 + 0.4 + 0.9 + 1.3 + 0.7 + 0.2) = \frac{1}{9}(7.5) = 0.8333 \]

現在得到的 scale 為:

\[scale_w = \frac{1}{0.8333} \approx 1.2 \]

第二步: 量化權重矩陣

使用公式:

\[W_q = \text{clamp}_{[-1, 1]}(\text{round}(W \times scale_w)) \]

我們首先將權重縮放 $ scale_w \approx 1.2 $ 倍:

\[W \times scale_w = \begin{bmatrix} 0.8 \times 1.2 & -0.5 \times 1.2 & 1.2 \times 1.2 \\ -1.5 \times 1.2 & 0.4 \times 1.2 & -0.9 \times 1.2 \\ 1.3 \times 1.2 & -0.7 \times 1.2 & 0.2 \times 1.2 \end{bmatrix} \begin{bmatrix} 0.96 & -0.6 & 1.44 \\ -1.8 & 0.48 & -1.08 \\ 1.56 & -0.84 & 0.24 \end{bmatrix} \]

然後我們將其取整並截斷到 $ [-1, 1] $ 的區間內:

\[W_q = \begin{bmatrix} 1 & -1 & 1 \\ -1 & 0 & -1 \\ 1 & -1 & 0 \end{bmatrix} \]

第三步: 反量化權重

最後我們反量化該權重:

\[W_{dequantized} = W_q \times scale_w \]

使用 scale_w 將權重恢復到原來的範圍,我們可以得到:

\[W_{dequantized} = \begin{bmatrix} 1 \times 1.2 & -1 \times 1.2 & 1 \times 1.2 \\ -1 \times 1.2 & 0 \times 1.2 & -1 \times 1.2 \\ 1 \times 1.2 & -1 \times 1.2 & 0 \times 1.2 \end{bmatrix} \begin{bmatrix} 1.2 & -1.2 & 1.2 \\ -1.2 & 0 & -1.2 \\ 1.2 & -1.2 & 0 \end{bmatrix} \]

例子 2: 啟用矩陣的量化

假設啟用矩陣 $ X $ 為:

\[X = \begin{bmatrix} 1.0 & -0.6 & 0.7 \\ -0.9 & 0.4 & -1.2 \\ 0.8 & -0.5 & 0.3 \end{bmatrix} \]

第一步: 計算啟用的 scale

對於每一行 (或者通道),計算其最大的絕對值

  • 第 1 行: 最大絕對值 = 1.0
  • 第 2 行: 最大絕對值 = 1.2
  • 第 3 行: 最大絕對值 = 0.8

計算每行的 scale:

\[\text{scale} = \begin{bmatrix} \frac{127}{1.0} \\ \frac{127}{1.2} \\ \frac{127}{0.8} \end{bmatrix} \begin{bmatrix} 127 \\ 105.83 \\ 158.75 \end{bmatrix} \]

步驟 2: 量化啟用矩陣

使用以下公式:

\[X_q = \text{clamp}_{[-128,127]}(\text{round}(X \times \text{scale})) \]

縮放相應的啟用值:

\[X \times \text{scale} = \begin{bmatrix} 1.0 \times 127 & -0.6 \times 127 & 0.7 \times 127 \\ -0.9 \times 105.83 & 0.4 \times 105.83 & -1.2 \times 105.83 \\ 0.8 \times 158.75 & -0.5 \times 158.75 & 0.3 \times 158.75 \end{bmatrix} \begin{bmatrix} 127 & -76.2 & 88.9 \\ -95.2 & 42.3 & -127 \\ 127 & -79.4 & 47.6 \end{bmatrix} \]

將值取整並截斷在 $ [-128, 127] $ 的範圍內:

\[X_q = \begin{bmatrix} 127 & -76 & 89 \\ -95 & 42 & -127 \\ 127 & -79 & 48 \end{bmatrix} \]

第三步: 反量化啟用

最後我們反量化啟用值:

\[X_{dequantized} = X_q \times \frac{1}{\text{scale}} \]

使用 scale 對值進行恢復:

\[X_{dequantized} = \begin{bmatrix} 127 \times \frac{1}{127} & -76 \times \frac{1}{127} & 89 \times \frac{1}{127} \\ -95 \times \frac{1}{105.83} & 42 \times \frac{1}{105.83} & -127 \times \frac{1}{105.83} \\ 127 \times \frac{1}{158.75} & -79 \times \frac{1}{158.75} & 48 \times \frac{1}{158.75} \end{bmatrix} \begin{bmatrix} 1.0 & -0.6 & 0.7 \\ -0.9 & 0.4 & -1.2 \\ 0.8 & -0.5 & 0.3 \end{bmatrix} \]


我們在量化啟用之前使用層歸一化 (Layer Normalization,LN) 以保持輸出的方差:

\[\text{LN}(x) = \frac{x - E(x)}{\sqrt{\text{Var}(x) + \epsilon}} \]

這裡 ε 是防止溢位的一個非常小的值

如前所述, round() 函式是不可微分的。我們使用 detach() 作為一個技巧,在反向傳播中實現可微分的 STE (Straight-Through Estimator):

# Adapted from https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn
import torch.nn.functional as F

def activation_quant(x):
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y
 
def weight_quant(w):
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight
        x_norm = LN(x)
        
        # A trick for implementing Straight−Through−Estimator (STE) using detach()
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        
        # Perform quantized linear transformation
        y = F.linear(x_quant, w_quant)
        return y

推理

在推理過程中,我們只是將權重量化為三值,而不重新反量化。我們對啟用採用相同的方法,使用 8 位精度,然後使用高效的運算元執行矩陣乘法,接著透過權重和啟用的 scale 進行除法。這能夠顯著提高推理的速度,特別是在最佳化的硬體上。您可以看到,在訓練期間反量化的過程與推理不同,因為矩陣乘法保持在 fp16/bf16/fp32 中以進行正確的訓練。

# Adapted from https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn
import torch.nn.functional as F

def activation_quant_inference(x):
    x = LN(x)
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    return y, scale
 
class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight # weights here are already quantized to (-1, 0, 1)
        w_scale = self.w_scale
        x_quant, x_scale = activation_quant_inference(x)
        y = efficient_kernel(x_quant, w) / w_scale / x_scale
        return y

1.58 位元的預訓練結果

在嘗試微調之前,我們首先嚐試復現 BitNet 論文中關於預訓練的結果。我們使用了一個小資料集 tinystories,以及一個 Llama3 8B 模型。我們發現,像論文中所做的那樣新增歸一化函式會提高效能。例如,在訓練 2000 步之後,我們在驗證集上的困惑度,沒有歸一化時為 6.3,使用歸一化後為 5.9。在這兩種情況下,訓練都是穩定的。

在有層歸一化 (藍色) 和沒有 (橙色) 的預訓練影像

在有層歸一化 (藍色) 和沒有 (橙色) 的預訓練影像

雖然這種方法在預訓練中看起來非常有趣,但只有少數機構能夠負擔大規模的預訓練。然而,因為存在有大量強大的預訓練模型,如果它們可以在預訓練後轉換為 1.58 位,將會非常有用。其他小組曾報告稱,微調的結果不如預訓練取得的結果那麼強大,因此我們展開了研究,看看我們是否能夠讓 1.58 位元地微調起作用。

1.58 位元的微調

當我們從預訓練的 Llama3 8B 權重開始微調時,模型表現略有提高,但並不如我們預期的那麼好。

Note: 所有的實驗都在 Nanotron 上進行,如果您對嘗試 1.58 位的預訓練或微調感興趣,可以檢視這個 PR 連結

微調曲線對比預訓練曲線

微調曲線對比預訓練曲線

為了理解原因,我們嘗試檢查隨機初始化模型和預訓練模型的權重分佈,以確定可能的問題。

隨機的權重分佈 (合併的標準差為 2)

隨機的權重分佈 (合併的標準差為 2)

預訓練 Llama3 的權重分佈

預訓練 Llama3 的權重分佈

兩個分佈的 scale 分別為:

隨機權重的 scale 分佈

預訓練權重的 scale 分佈

初始隨機權重分佈是兩個正態分佈的混合:

  • 一個標準差為 $$ 0.025 $$
  • 另一個標準差為 $$ \frac{0.025}{\sqrt{2 \cdot \text{num_hidden_layers}}} = 0.00325 $$

這是因為在 nanotron 中對列線性權重和行線性權重使用了不同的標準差。在量化版本中,所有矩陣只有兩個權重尺度 (50.25 和 402),這兩個尺度分別是每個矩陣權重的絕對值的倒數的平均值: scale = 1.0 / w.abs().mean().clamp_(min=1e-5)

  • 對於 $$\text{scale} = 50.25 $$,$$ w.abs().mean() = 0.0199 $$,導致 $$\text{std} = 0.025 $$,與我們的第一個標準差相匹配。用於推導標準差的公式基於 $$ |w| $$ 的半正態分佈的期望:

\[\mathbb{E}(|w|) = \text{std}(w) \cdot \sqrt{\frac{2}{\pi}} \]

  • 對於 $$\text{scale} = 402 $$,$$ w.abs().mean() = 0.0025 $$,導致 $$\text{std} = 0.00325 $$

另一方面,預訓練權重的分佈看起來像是一個標準差為 $ 0.013 $ 的正態分佈。

顯然,預訓練模型從更多資訊 (scale) 開始,而隨機初始化的模型從實際上沒有資訊開始,並隨著時間逐漸增加資訊。我們的結論是,從隨機權重開始給予模型最小的初始資訊,從而實現逐步學習過程,而在微調期間,引入 BitLinear 層會使模型喪失所有先前的資訊。

為了改善微調結果,我們嘗試了不同的技術。例如,我們嘗試過使用 per-row 和 per-column 量化而不是 per-tensor 量化,以保留更多來自 Llama 3 權重的資訊。我們還嘗試改變尺度計算的方式: 不再僅僅將權重的平均絕對值作為尺度,而是將異常值 (超過 k 倍平均絕對值的值,其中 k 是我們在實驗中嘗試變化的常數) 的平均絕對值作為尺度,但我們並沒有注意到明顯的改善。

def scale_outliers(tensor, threshold_factor=1):
    mean_absolute_value = torch.mean(torch.abs(tensor))
    threshold = threshold_factor * mean_absolute_value
    outliers = tensor[torch.abs(tensor) > threshold]
    mean_outlier_value = torch.mean(torch.abs(outliers))
    return mean_outlier_value

def weight_quant_scaling(w):
    scale = 1.0 / scale_outliers(w).clamp_(min=1e-5)
    quantized_weights = (w * scale).round().clamp_(-1, 1) / scale
    return quantized_weights

我們觀察到,隨機權重和 Llama 3 權重在損失開始時的數值約為 13,這表明當引入量化時,Llama 3 模型失去了所有先前的資訊。為了進一步研究模型在這個過程中失去了多少資訊,我們嘗試了 per-group 量化。

作為一個合理性檢查,我們首先將 group 大小設定為 1,這基本上意味著沒有量化。在這種情況下,損失從 1.45 開始,與正常微調時的情況相同。然而,當我們將組大小增加到 2 時,損失跳升到大約 11。這表明即使組大小最小為 2,模型仍幾乎失去了所有資訊。

為了解決這個問題,我們考慮逐漸引入量化而不是突然將其應用於每個張量的權重和啟用。為了實現這一點,我們引入了一個 lambda 值來控制這個過程:

lambda_ = ?
x_quant = x + lambda_ *(activation_quant(x) - x).detach()
w_quant = w + lambda_ *(weight_quant(w) - w).detach()

lambda 設定為 0 是 , 實際上沒有量化發生 , 當 lambda=1 時 , 將應用完全的量化 .

我們最初測試了一些離散的 lambda 值,比如 0.25、0.5、0.75 和 1。然而,這種方法並沒有在結果上帶來顯著的改善,主要是因為 lambda=0.25 已經足夠高,使損失開始得很高。

當 lambda = 0.25->0.5->0.75->1 時的微調影像

因此,我們決定嘗試一個根據訓練步驟動態調整的 lambda 值。

使用這種動態的 lambda 值導致更好的損失收斂,但在推理過程中,當 lambda 設定為 1 時,困惑度 (perplexity 或者 ppl) 的結果仍然遠非令人滿意。我們意識到這很可能是因為模型在 lambda=1 的情況下還沒有受過足夠長時間的訓練。為了解決這個問題,我們調整了我們的 lambda 值來改善訓練過程。

lambda_ = min(2 * training_step / total_training_steps, 1)

在這種配置下,經過 2000 步之後,我們有:

lambda = min(2*training_step/total_training_steps, 1) 時的微調影像

lambda = min(2*training_step/total_training_steps, 1) 時的微調影像

我們的微調方法整體上顯示出更好的收斂性。你可以觀察到在大約 1000 步時損失曲線略微增加,這對應於我們開始接近 lambda=1 或完全量化的時候。然而,在這一點之後,損失立即開始再次收斂,導致困惑度約為 4,得到了改善。

儘管取得了進展,但當我們在 WikiText 資料集上測試量化模型 (而不是我們用於微調的 tinystories 資料集) 時,困惑度非常高。這表明在特定資料集上以低位元模式微調模型會導致其喪失大部分通用知識。這個問題可能是因為我們在三值權重中追求的最小表示在不同資料集之間可能會有顯著差異。為解決這個問題,我們擴充套件了我們的訓練過程,包括了更大的 FineWeb-edu 資料集。我們保持了一個 lambda 值為:

lambda_ = min(training_step/1000, 1)

我們選擇了這個 lambda 值,因為它似乎是對模型進行 warmup 的一個很好的起點。然後,我們在 FineWeb-edu 資料集上使用學習率為 1e-4,訓練了 5000 步。訓練過程中使用了一個批次大小 (BS) 為 2B,總共訓練了 10B 個 token。

找到合適的學習率和合適的衰減率是具有挑戰性的; 這似乎是模型效能的一個關鍵因素。

在 Fineweb-edu 上進行 warmup 量化時的微調影像

在 Fineweb-edu 上進行 warmup 量化時的微調影像

在 FineWeb-Edu 上微調後,在 WikiText 資料集上達到 12.2 的困惑度是相當令人印象深刻的,考慮到我們只使用了 100 億個標記。其他評估指標也顯示出了強大的效能,考慮到資料量有限 (請參見結果)。

嘗試平滑 lambda 接近 1 時的急劇增加也是一個不錯的想法。為了實現這一點,考慮使用 lambda 排程器,這些排程器在開始時呈指數增長,然後在接近 1 時趨於平穩。這種方法可以幫助模型更平穩地適應 lambda 值的變化,避免突然的波動。

def scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step)**k

對於不同的 k 值,總預熱步數為 1,我們有如下圖表:

不同 k 值時的指數排程器

我們使用表現最好的學習率 1e-4 進行了 4 次實驗 , 測試的 k 值分別為 4, 6, 8, 10.

使用不同指數排程器時的微調影像

使用不同指數排程器時的微調影像

平滑效果很好,不像線性排程器那樣出現尖峰。然而,困惑度並不理想,大約保持在 15 左右,對下游任務的表現也沒有改善。

我們還注意到了開始時的尖峰,模型難以從中恢復。當 lambda = 0 時,基本上沒有量化,所以損失開始很低,大約在 2 左右。但在第一步之後,出現了一個尖峰,類似於線性排程器的情況 (如上面的藍色圖表所示)。因此,我們嘗試了另一種排程器即 Sigmoid 排程器,它開始緩慢上升,迅速上升到 1,然後在接近 1 時趨於穩定。

def sigmoid_scheduler(step, total_steps, k):
    # Sigmoid-like curve: slow start, fast middle, slow end
    normalized_step = step / total_steps
    return 1 / (1 + np.exp(-k *(normalized_step - 0.5)))

對於不同的 k 值有以下的曲線:

對於不同 k 值的 Sigmoid 排程器

對於不同 k 值的 Sigmoid 排程器

我們這次在 k 為 15, 20, 25, 40 和 100 時進行了實驗:

使用 Sigmoid 排程器進行微調的影像

使用 Sigmoid 排程器進行微調的影像

lambda 的急劇增加導致在第 500 步左右出現不穩定,並沒有解決第一次發散問題。然而,對於 $$ k = 100 $$,我們觀察到在下游任務中有一些改善 (請參閱結果表),儘管困惑度仍保持在 13.5 左右。儘管如此,與線性排程器相比,並沒有顯示明顯的效能提升。

此外,我們嘗試了使用隨機權重和各種學習率從頭開始訓練模型的實驗。這使我們能夠比較我們的微調方法與傳統的預訓練方法的有效性。

不同學習率時的訓練影像

不同學習率時的訓練影像

所有從隨機權重訓練的模型都沒有比我們的微調模型表現更好。我們在這些模型中實現的最佳困惑度為 26,與我們的微調方法的結果相比略遜一籌。

擴充套件到 100B 個 token!

我們將實驗擴充套件到了 100B 個 token,以檢視是否能夠達到 Llama 3 8B 模型的效能水平。我們進行了更長時間的訓練執行,從較短執行中表現最佳的檢查點開始,使用線性排程器,並持續微調了 45,000 步。我們嘗試了不同的學習率,雖然在某些指標上模型的表現接近 Llama 3 模型,但平均而言,仍然落後一些。

這裡是我們在訓練過程中在不同 checkpoint 評估的一些指標的例子:

在訓練中不同學習率的多個指標評估結果

在訓練中不同學習率的多個指標評估結果

平均的分數如下:

在訓練中不同學習率的平均評估結果

在訓練中不同學習率的平均評估結果

在更小的模型上的實驗

在我們對 SmolLM 等較小模型進行的初始實驗中,我們觀察到 warmup 量化技術並沒有像對較大模型那樣帶來太多改進。這表明 warmup 量化的有效性可能與模型的大小和複雜性更密切相關。

例如,這裡是 SmolLM 135M 模型的損失曲線,比較了從一開始就使用 warmup 量化和完全量化的情況。有趣的是,這些曲線非常接近,得到的困惑度並沒有顯著不同。

有 warmup 量化和沒有時的 Smoll LLM 微調實驗

有 warmup 量化和沒有時的 Smoll LLM 微調實驗

對比與結論

BitNet 在與基準方法相比表現出色,特別是在較低位元數情況下。根據論文,BitNet 實現了與 8 位模型相當的分數,但推理成本顯著更低。在 4 位模型的情況下,僅量化權重的方法勝過同時量化權重和啟用的方法,因為啟用更難量化。然而,使用 1.58 位權重的 BitNet 超越了僅權重和權重與啟用量化方法。

下表展示了在 Llama3 8B 的 10B 個 token 微調過程之後各種指標的結果。這些結果與其他模型架構的結果進行了比較,以提供對效能的全面概述 (所有評估均使用 LightevalNanotron 格式模型上進行)。

與 Llama 模型的指標比較: 線性表示線性 lambda 排程器,Sigmoid 表示 Sigmoid 排程器 (在我們的情況下 k = 100)

與 Llama 模型的指標比較: 線性表示線性 lambda 排程器,Sigmoid 表示 Sigmoid 排程器 (在我們的情況下 k = 100)

在僅使用三值權重進行 10B 個 token 微調後,該模型展現出令人印象深刻的效能,特別是與經歷了更加廣泛訓練的其他模型相比。例如,它勝過了在資料集規模顯著大得多的 100B 個 token 上訓練的 Bitnet 7B 模型。此外,它的表現也優於 FBI LLM (Fully Binarized LLM) 模型,後者在更龐大的 1.26T 個 token 上進行了蒸餾。這突顯了該模型的效率和有效性,儘管其微調過程相對規模較小。

對於 100B 個 token 的實驗,我們擁有的表現最佳的 checkpoint 如下:

100B 個 token 微調後與 Llama 模型的指標比較

100B 個 token 微調後與 Llama 模型的指標比較

要複製這些結果,您可以檢視這個 PR 將模型轉換為 Nanotron 格式,解壓權重 (檢查函式 unpack_weights),並使用 lighteval。

請注意,儘管這些模型是從一個 Instruct-tuned 模型微調而來,它們仍需要使用 Instruct 資料集進行微調。這些可以被視為基礎模型。

使用的運算元和測試標準

為了從 BitNet 低精度權重中受益,我們將它們打包成一個 int8 張量 (這使得引數數量從 80 B 降至 28 B!)。在推理過程中,這些權重在執行矩陣乘法之前必須進行解包。我們在 Cuda 和 Triton 中實現了自定義核心,以處理矩陣乘法過程中的即時解包。對於矩陣乘法本身,我們採用了快取分塊矩陣乘法技術。為了充分理解這種方法,讓我們首先回顧一些 Cuda 程式設計基礎知識。

基礎的 GPU 概念: 執行緒、塊、和共享記憶體

在深入瞭解快取分塊矩陣乘法之前,瞭解一些基本的 GPU 概念是很重要的:

  • 執行緒 (thread) 和塊 (block): GPU 同時執行成千上萬個執行緒。這些執行緒被分組成塊,每個塊獨立執行。網格由這些塊 (grid) 組成,代表整個程式空間。例如,在矩陣乘法中,每個執行緒可能負責計算輸出矩陣的一個單元。
  • 共享記憶體 (share memory): 每個塊都可以訪問有限量的共享記憶體,比全域性記憶體 (global memory, GPU 上的主記憶體) 要快得多。然而,共享記憶體大小有限,並在塊內的所有執行緒之間共享。有效利用共享記憶體是提高 GPU 程式效能的關鍵。

矩陣乘法中的挑戰

在 GPU 上簡單實現矩陣乘法可能涉及每個執行緒透過直接從全域性記憶體讀取所需元素來計算結果矩陣的單個元素。然而,這種方法可能效率低下,原因如下:

  • 記憶體頻寬: 相對於 GPU 核心執行計算的速度,訪問全域性記憶體相對較慢。如果每個執行緒直接從全域性記憶體讀取矩陣元素,訪存時間可能成為瓶頸。
  • 冗餘資料訪問: 在矩陣乘法中,輸入矩陣的許多元素被多次使用。如果每個執行緒獨立從全域性記憶體獲取所需資料,相同的資料可能會被多次載入到 GPU 中,導致效率低下。例如,如果每個執行緒用於計算輸出矩陣中的單個元素,則負責計算位置 (i, j) 的執行緒將需要從全域性記憶體載入矩陣 A 的第 i 行和矩陣 B 的第 j 列。然而,其他執行緒,例如負責計算位置 (i+1, j) 的執行緒,無法重用這些資料,將不得不再次從全域性記憶體中載入相同的第 j 列。

分塊的概念

分塊是一種用於解決這些挑戰的技術,主要用於 FlashAttention 技術中以提高核心的效率。基本思想是將矩陣分成更小的子矩陣,稱為塊 (tile),這些塊可以適應 GPU 的共享記憶體。計算不再一次完成整個輸出矩陣,而是將計算分解為小塊,逐塊處理。

在矩陣乘法的背景下,這意味著將矩陣 A 和 B 劃分為塊,將這些塊載入到共享記憶體中,然後在這些較小的塊上執行乘法。這種方法允許執行緒重複使用儲存在快速共享記憶體中的資料,減少了重複訪問全域性記憶體的需求。

具體操作如下:

  • 將塊載入到共享記憶體: 每個執行緒塊協同地將矩陣 A 的一個小塊和相應的矩陣 B 的一個小塊從全域性記憶體載入到共享記憶體。這個操作對每個小塊只執行一次,然後該小塊被塊中的執行緒多次重複使用。
  • 計算部分乘積: 一旦塊載入到共享記憶體中,每個執行緒計算部分乘積。由於塊中的所有執行緒都在共享記憶體中的相同塊上工作,它們可以有效地重複使用資料,而無需額外訪問全域性記憶體。
  • 累積結果: 計算完一個塊的部分乘積後,執行緒將從矩陣 A 和 B 中載入下一個塊到共享記憶體,並重復這個過程。結果累積在暫存器 (或本地記憶體) 中,一旦所有塊都被處理,輸出矩陣元素的最終值將被寫回全域性記憶體。

分塊矩陣乘法圖示 (來源 https://cnugteren.github.io/tutorial/pages/page4.html)

分塊矩陣乘法圖示 (來源 https://cnugteren.github.io/tutorial/pages/page4.html)

現實的考慮

在實現快取分塊矩陣乘法時,考慮了幾個因素:

  • 塊大小: 塊的大小應該選擇以平衡能夠放入共享記憶體的資料量和全域性記憶體訪問次數之間的權衡。
  • 記憶體合併: 全域性記憶體訪問應該進行記憶體合併,這意味著相鄰的執行緒訪問相鄰的記憶體位置。
  • 佔用率: 應該選擇每個塊中的執行緒數和網格中的塊數,以確保高佔用率,即在 GPU 上有儘可能多的活動執行緒束 (warp) (一個執行緒束是一組 32 個執行緒),以隱藏記憶體延遲。

Triton 運算元

下面是我們作為基準的一個 triton 運算元:

@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

    for i in range(4):
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K) ):
            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j

            # BLOCK_SIZE_K must be a divisor of K / 4
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K // 4 - j * BLOCK_SIZE_K, other=0)
            mask = 3<<(2*i)
            b = ((b_uint8 & mask) >> (2*i))

            # We accumulate the tiles along the K dimension.
            tensor_full = tl.full((1,), 1, dtype=tl.int8)

            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul(a, b):
    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
 _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META:(triton.cdiv(M, META['BLOCK_SIZE_M'])* triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

程式碼解析

  1. 確定分塊位置

運算元首先確定每個執行緒塊負責的輸出矩陣的塊 (tile):

  • pid 是每個執行緒塊的唯一識別符號,使用 tl.program_id(axis=0) 獲得。
  • 網格被分成一組執行緒塊 (GROUP_SIZE_M )。每個組處理輸出矩陣的一部分。
  • pid_mpid_n 是分塊在 M 和 N 維度上的座標,分別表示。
  • 計算偏移量 (offs_amoffs_bnoffs_k ) 以確定每個塊中的執行緒將處理矩陣 A 和 B 的哪些元素。
  1. 載入和計算分塊

運算元使用迴圈以 BLOCK_SIZE_K 的塊大小迭代 K 維度。對於每個塊:

  • 載入分塊: 從全域性記憶體載入矩陣 A 和 B 的分塊。

  • 解包矩陣 B: 運算元假設矩陣 B 是使用 int8 值打包的,這意味著每個元素實際上代表四個較小的值打包成一個位元組。解壓過程發生在迴圈內:

    • 從全域性記憶體載入 b_uint8 作為打包的 int8
    • 解壓每個打包的值以獲得用於計算的實際權重值。
  • 點積: 核心計算從矩陣 A 和 B 載入的分塊的點積,並將結果累積到 accumulator 中。accumulator 儲存輸出矩陣 C 的分塊的部分結果。

  1. 儲存結果

在處理完沿著 K 維度的所有分塊之後,儲存在 accumulator 中的最終結果被轉換為 float16 ,並寫回到全域性記憶體中矩陣 C 的相應分塊。寫入過程使用掩碼來確定記憶體邊界,以確保只寫入有效元素。

要獲取程式碼的更詳細解釋,請檢視這個 PR

基準測試

我們對我們的運算元進行了基準測試,與使用 @torch.compile 解壓權重然後在 BF16 精度下執行矩陣乘法的方法進行了對比,發現兩種方法的效能幾乎相同。為了確保準確的基準測試,我們在 2000 次迭代中執行了矩陣乘法操作,並在最後 1000 次迭代中計算平均時間,以消除與初始載入或編譯相關的任何低效性。下面是顯示基準測試結果的圖表。我們還測試了各種矩陣大小,其中 x 軸表示對數尺度上的乘法次數,y 軸顯示平均時間 (以毫秒為單位)。

Triton 運算元對比 torch.compile

Triton 運算元對比 torch.compile

我們還嘗試使用 BitBlas,這是一個旨在使用混合精度執行矩陣運算的軟體庫。它透過允許在較低精度格式 (如 INT8、INT4,甚至 INT2) 而不是傳統的 FP32 或 FP16 格式中進行計算,來幫助最佳化這些操作。

基準測試結果令人鼓舞,如圖所示,BitBlas 在低精度下優於我們的自定義核心和 Torch 的 matmul 函式。

Bitblas 測試

Bitblas 測試

然而,在模型載入過程中,BitBlas 需要編譯適合權重矩陣形狀的核心,並將它們儲存在原生代碼庫中,這可能會增加初始載入時間。

結論

總之,隨著大型語言模型的不斷擴充套件,透過量化來減少它們的計算需求至關重要。本博文探討了 1.58 位量化的方法,該方法使用了三值權重。雖然在 1.58 位進行預訓練模型是資源密集型的,但我們已經證明,透過一些技巧,可以將現有模型微調到這個精度水平,實現高效的效能而不犧牲準確性。透過專門的核心最佳化推理速度,BitNet 為使大型語言模型更具實用性和可擴充套件性開啟了新的可能性。

致謝

我們要衷心感謝 Leandro von Werra、Thomas Wolf 和 Marc Sun 在整個專案中提供的寶貴幫助和見解。我們還要感謝 Omar Sanseviero 和 Pedro Cuenca 在完善這篇博文方面的貢獻,幫助我們清晰有效地向人工智慧社群傳達我們的發現。此外,我們要感謝 GeneralAI 團隊在 BitNet 專案上的開創性工作。他們的研究對我們的努力具有基礎性意義,我們特別感謝他們在論文中提供的清晰準確的資料。

更多資源

  1. H. Wang et al., BitNet: Scaling 1-bit Transformers for Large Language Models . arxiv paper
  2. S. Ma et al., The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits . arxiv paper
  3. S. Ma et al., The Era of 1-bit LLMs: Training Tips, Code and FAQ . link
  4. RJ. Honicky, Are All Large Language Models Really in 1.58 Bits? . blogpost
  5. L. Mao, CUDA Matrix Multiplication Optimization . blogpost
  6. Tutorial: OpenCL SGEMM tuning for Kepler . link
  7. CUDAMODE . github, youtube
  8. Wen-mei W. Hwu, David B. Kirk, Izzat El Hajj, Programming Massively Parallel Processors : A Hands-on Approach

相關文章