英偉達玩轉剪枝、蒸餾:把Llama 3.1 8B引數減半,效能同尺寸更強

机器之心發表於2024-08-16
小模型崛起了。

上個月,Meta 釋出了 Llama 3.1 系列模型,其中包括 Meta 迄今為止最大的 405B 模型,以及兩個較小的模型,引數量分別為 700 億和 80 億。

Llama 3.1 被認為是引領了開源新時代。然而,新一代的模型雖然效能強大,但部署時仍需要大量計算資源。

因此,業界出現了另一種趨勢,即開發小型語言模型 (SLM),這種模型在許多語言任務中表現足夠出色,部署起來也非常便宜。

最近,英偉達研究表明,結構化權重剪枝知識蒸餾相結合,可以從初始較大的模型中逐步獲得較小的語言模型

圖片

圖靈獎得主、Meta 首席 AI 科學家 Yann LeCun 也點贊轉帖了該研究。

經過剪枝和蒸餾,英偉達研究團隊將 Llama 3.1 8B 提煉為 Llama-3.1-Minitron 4B 開源了出來。這是英偉達在 Llama 3.1 開源系列中的第一個作品。

Llama-3.1-Minitron 4B 的表現優於類似大小的最先進的開源模型,包括 Minitron 4B、Phi-2 2.7B、Gemma2 2.6B 和 Qwen2-1.5B。

圖片

這項研究的相關論文早在上個月已經放出了。

圖片
  • 論文連結:https://www.arxiv.org/pdf/2407.14679

  • 論文標題:Compact Language Models via Pruning and Knowledge Distillation

剪枝和蒸餾

剪枝使模型變得更小、更精簡,可以透過刪除層(深度剪枝)或刪除神經元和注意力頭以及嵌入通道(寬度剪枝)來實現。剪枝通常伴隨著一定程度的再訓練,以恢復準確率

模型蒸餾是一種將知識從大型複雜模型(通常稱為教師模型)遷移到較小、較簡單的學生模型的技術。目標是建立一個更高效的模型,該模型保留了原始較大模型的大部分預測能力,同時執行速度更快且資源消耗更少。

蒸餾方式主要包括兩種:SDG 微調與經典知識蒸餾,這兩種蒸餾方式互補。本文主要關注經典知識蒸餾方法。

英偉達採用將剪枝與經典知識蒸餾相結合的方式來構造大模型,下圖展示了單個模型的剪枝和蒸餾過程(上)以及模型剪枝和蒸餾的鏈條(下)。具體過程如下:

1. 英偉達從 15B 模型開始,評估每個元件(層、神經元、頭和嵌入通道)的重要性,然後對模型進行排序和剪枝,使其達到目標大小:8B 模型。

2. 接著使用模型蒸餾進行了輕度再訓練,原始模型作為老師,剪枝後的模型作為學生。

3. 訓練結束後,以小模型(8B)為起點,剪枝和蒸餾為更小的 4B 模型。

圖片

從 15B 模型進行剪枝與蒸餾的過程。

需要注意的點是,在對模型剪枝之前,需要先了解模型的哪部分是重要的。英偉達提出了一種基於啟用的純重要性評估策略,該策略可以同時計算所有相關維度(深度、神經元、頭和嵌入通道)的資訊,使用一個包含 1024 個樣本的小型校準資料集,並且只需要前向傳播。這種方法相比依賴梯度資訊並需要反向傳播的策略更加簡單且具有成本效益。

剪枝過程中,你可以針對給定軸或軸組合在剪枝和重要性估計之間進行迭代交替。實證研究顯示,使用單次重要性估計就足夠了,迭代估計不會帶來額外的好處。

利用經典知識蒸餾進行重新訓練

下圖 2 展示了蒸餾過程,其中 N 層學生模型(剪枝後的模型)是從 M 層教師模型中(原始未剪枝模型)蒸餾而來。學生模型透過最小化嵌入輸出損失、logit 損失以及對映到學生塊 S 和教師塊 T 的 Transformer 編碼器特定損失組合來學習。

圖片

圖 2:蒸餾訓練損失。

剪枝和蒸餾最佳實踐

英偉達基於緊湊語言模型剪枝知識蒸餾的廣泛消融研究,將自己的學習成果總結為以下幾種結構化壓縮最佳實踐。

一是調整大小。

  • 要訓練一組 LLM,首先訓練最大的一個,然後迭代地剪枝和蒸餾以獲得較小的 LLM。

  • 如果使用多階段訓練策略來訓練最大的模型,最好剪枝並對訓練最後階段獲得的模型進行重新訓練。

  • 對最接近目標大小的可用源模型進行剪枝

二是剪枝

  • 優先考慮寬度剪枝而不是深度剪枝,這對於 15B 引數規模以下的模型效果很好。

  • 使用單樣本(single-shot)重要性估計,因為迭代重要性估計沒有任何好處。

三是重新訓練。

  • 僅使用蒸餾損失進行重新訓練,而不是常規訓練。

  • 當深度明顯減少時,使用 logit、中間狀態和嵌入蒸餾。

  • 當深度沒有明顯減少時,使用 logit-only 蒸餾。

Llama-3.1-Minitron:將最佳實踐付諸應用

Meta 最近推出了功能強大的 Llama 3.1 開源模型系列,在許多基準測試中可與閉源模型相媲美。Llama 3.1 的引數範圍從巨大的 405B 到 70B、8B。

憑藉 Nemotron 蒸餾的經驗,英偉達著手將 Llama 3.1 8B 模型蒸餾為更小、更高效的 4B 模型,採取以下措施:

  • 教師微調

  • Depth-only 剪枝

  • Width-only 剪枝

  • 準確率基準

  • 效能基準

教師微調

為了糾正模型訓練所基於的原始資料集的分佈偏差,英偉達首先在他們的資料集上(94B token)對未剪枝的 8B 模型進行了微調。實驗表明,如果不糾正分佈偏差,教師模型在蒸餾時會為資料集提供次優指導。

Depth-only 剪枝

為了從 8B 降到 4B,英偉達剪枝了 16 層(50%)。他們首先透過從模型中刪除每個層或連續子層組來評估它們的重要性,並觀察下游任務中 LM 損失的增加或準確率的降低。

下圖 5 顯示了刪除 1、2、8 或 16 層後驗證集上的 LM 損失值。例如,第 16 層的紅色圖表示如果刪除前 16 層,則出現 LM 損失。第 17 層表示如果保留第一層並刪除第 2 至第 17 層,也出現 LM 損失。英偉達觀察到:開始和結束的層是最重要的。

圖片

圖 5:depth-only 剪枝中層的重要性。

然而,英偉達觀察到,這種 LM 損失不一定與下游效能直接相關。

下圖 6 顯示了每個剪枝模型的 Winogrande 準確率,它表明最好刪除第 16 到第 31 層,其中第 31 層是倒數第二層,剪枝模型的 5-shot 準確率明顯高於隨機準確率 (0.5)。英偉達採納了這一見解,刪除了第 16 到第 31 層。

圖片

圖 6:當刪除 16 層時,在 Winogrande 任務上的準確率

Width-only 剪枝

英偉達沿寬度軸剪枝了嵌入(隱藏)和 MLP 中間維,以壓縮 Llama 3.1 8B。具體來說,他們使用前面描述的基於啟用的策略來計算每個注意頭、嵌入通道和 MLP 隱藏維度的重要性分數。

在重要性估計之後,英偉達選擇

  • 將 MLP 中間維從 14336 剪枝到 9216。

  • 將隱藏大小從 4096 剪枝到 3072。

  • 重新訓練注意頭數量和層數。

值得一提的是,在單樣本剪枝之後,寬度剪枝的 LM 損失高於深度剪枝。然而,經過短暫的重新訓練後,趨勢發生了逆轉。

準確率基準

英偉達使用以下引數對模型進行蒸餾

  • 峰值學習率 = 1e-4

  • 最小學習率 = 1e-5

  • 40 步線性預熱

  • 餘弦衰減計劃

  • 全域性批次大小 = 1152

下表 1 顯示了 Llama-3.1-Minitron 4B 模型變體(寬度剪枝和深度剪枝)與原始 Llama 3.1 8B 模型、其他類似大小的模型在跨多個領域的基準測試中的效能比較。總體而言,英偉達再次證實了寬度剪枝策略相較於遵循最佳實踐的深度剪枝的有效性。

圖片

表 1:Minitron 4B base 模型相較於類似規模 base 模型的準確率比較。

為了驗證蒸餾後的模型是否可以成為強大的指令模型,英偉達使用 NeMo-Aligner 對 Llama-3.1-Minitron 4B 模型進行了微調。

他們使用了 Nemotron-4 340B 的訓練資料,在 IFEval、MT-Bench、ChatRAG-Bench 和 Berkeley Function Calling Leaderboard (BFCL) 上進行了評估,以測試指令遵循、角色扮演、RAG 和函式呼叫功能。最後確認 Llama-3.1-Minitron 4B 模型可以成為可靠的指令模型,其表現優於其他基線 SLM。

圖片

表 2:對齊 Minitron 4B base 模型與類似規模的對齊模型的準確率比較。

效能基準

英偉達利用 NVIDIA TensorRT-LLM(一種用於最佳化 LLM 推理的開源工具包)最佳化了 Llama 3.1 8B 和 Llama-3.1-Minitron 4B 模型。

下兩張圖顯示了不同模型在不同用例下以 FP8 和 FP16 精度每秒的吞吐量請求,表示為 8B 模型的 batch size 為 32 的輸入序列長度 / 輸出序列長度 (ISL/OSL) 組合以及 4B 模型的 batch size 為 64 的輸入序列長度 / 輸出序列長度 (ISL/OSL) 組合,這要歸功於在一塊英偉達 H100 80GB GPU 上,較小的權重允許較大的 batch size。

Llama-3.1-Minitron-4B-Depth-Base 變體是最快的,平均吞吐量約為 Llama 3.1 8B 的 2.7 倍,而 Llama-3.1-Minitron-4B-Width-Base 變體的平均吞吐量約為 Llama 3.1 8B 的 1.8 倍。與 BF16 相比,在 FP8 中部署還可使這三種型號的效能提高約 1.3 倍。

圖片
圖片

圖 8:組合:Llama 3.1 8B 為 BS=32,Llama-3.1-Minitron 4B 型號為 BS=64。1x H100 80GB GPU。

結論

剪枝和經典知識提煉是一種非常經濟高效的方法,可以逐步獲得更小尺寸的 LLM,與在所有領域從頭開始訓練相比,可實現更高的準確性。與合成資料式微調或從頭開始預訓練相比,這是一種更有效且資料效率更高的方法。

Llama-3.1-Minitron 4B 是英偉達首次嘗試使用最先進的開源 Llama 3.1 系列完成的探索。要在 NVIDIA NeMo 中使用 Llama-3.1 的 SDG 微調,可參閱 GitHub 上的 /sdg-law-title-generation 部分。

有關更多資訊,請參閱以下資源:

  • https://arxiv.org/abs/2407.14679

  • https://github.com/NVlabs/Minitron

  • https://huggingface.co/nvidia/Llama-3.1-Minitron-4B-Width-Base

  • https://huggingface.co/nvidia/Llama-3.1-Minitron-4B-Depth-Base

參考連結:

https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/

相關文章