微調大模型,AMD MI300X就夠了!跟著這篇部落格微調Llama 3.1 405B,效果媲美H100

机器之心發表於2024-10-08
隨著 AI 模型的引數量越來越大,對算力的需求也水漲船高。

比如最近,Llama-3.1 登上了最強開源大模型的寶座,但超大杯 405B 版本的記憶體就高達 900 多 GB,這對算力構成了更加苛刻的挑戰。

如何降低算力的使用成本和使用門檻,已經成為許多公司尋求突破的關鍵。Felafax 就是其中的一家創業公司,致力於簡化 AI 訓練叢集的搭建流程。
圖片
Nikhil Sonti 和 Nikhin Sonti 創立了 Felafax,他們的口號是在構建開源 AI 平臺,為下一代 AI 硬體服務,將機器學習的訓練成本降低 30%。

與英偉達相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的價效比,按每美元計算,其效能表現更為出色。

最近,Felafax 的聯合創始人 Nikhil Sonti 釋出了一篇部落格,詳細分享瞭如何透過 8 張 AMD MI300X GPU 和 JAX 微調 LLaMA 3.1 405B 模型的方法,所有程式碼現已開源。

圖片

Github 連結:https://github.com/felafax/felafax

機器之心對部落格內容進行了不改變原意的編譯、整理,以下是部落格內容:

JAX 尤其適合非英偉達硬體

JAX 是一個強大的機器學習庫,結合了類似 NumPy 的 API、自動微分功能以及 Google 的 XLA 編譯器。它在模型並行化方面提供了優秀的 API,因此非常適合像 LLaMA 3.1 405B 這樣的超大模型訓練。

在使用 AMD 硬體時,JAX 有幾個明顯的優勢:

  • 多硬體並行支援:JAX 採用 XLA(加速線性代數)編譯器,將計算編譯為硬體無關的中間表示(HLO),這意味著同樣的 JAX 程式碼無需修改便可高效執行在不同硬體後端,包括 AMD GPU。
  • 獨立於底層硬體:XLA 編譯器的最佳化策略是通用的,不針對某個特定的硬體平臺。這使得任何支援 XLA 的硬體裝置(如 CPU、GPU、TPU)都能受益於這些最佳化,獲得更好的效能表現。
  • 極高的適應性:從 NVIDIA 轉移到 AMD(或其他硬體)時,JAX 只需做極少的程式碼改動。而相較之下,PyTorch 與英偉達的 CUDA 生態系統緊密耦合,遷移過程相對複雜。

因此,JAX 成為了我們在非英偉達硬體上的最佳選擇。

拉取 Docker 映象:
docker pull rocm/jax:latest
啟動 Docker 容器:
# Pull the Docker Image:
docker pull rocm/jax:latest 

# Start the Docker Container:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest

# Verify the Installation: 
python3 -c 'import jax; print(jax.devices())'
驗證安裝
python3 -c 'import jax; print (jax.devices ())'
訓練使用了一個配備了 8 張 AMD MI300x GPU 的 AMD 節點。每張 MI300x 擁有 192GB 的 HBM3 記憶體,效能表現與最新的英偉達 H100 GPU 相比非常出色。
圖片
與英偉達 H100 的比較,來源:TensorWave

訓練 LLaMA 405B:效能與可擴充套件性

使用 JAX,可以成功地在 AMD GPU 上訓練 LLaMA 405B 模型。我們使用 LoRA 微調,將所有模型權重和 LoRA 引數都設為 bfloat16,LoRA rank 設為 8,LoRA alpha 設為 16:

  • 模型大小:LLaMA 模型的權重佔用了約 800GB 的視訊記憶體。
  • LoRA 權重 + 最佳化器狀態:大約佔用了 400GB 的視訊記憶體。
  • 視訊記憶體總使用量:佔總視訊記憶體的 77%,約 1200GB。
  • 限制:由於 405B 模型的規模過大,batch 大小和序列長度的空間有限,使用的 batch size 為 16,序列長度為 64。
  • JIT 編譯:由於空間限制,無法執行 JIT 編譯版本;它可能需要比急切模式稍多的空間。
  • 訓練速度:使用 JAX 急切模式,約為 35 tokens / 秒。
  • 記憶體效率:穩定在約 70% 左右。
  • 擴充套件性:在 8 張 GPU 上,使用 JAX 的擴充套件性接近線性。

由於硬體和視訊記憶體的限制,我們無法執行 JIT 編譯版本的 405B 模型,整個訓練過程是在 JAX 的急切模式下執行的,因此還有很大的進步空間。

下圖中顯示了在一次微調訓練步驟中,8 張 GPU 的視訊記憶體利用率和 rocm-smi 輸出:

GPU 利用率:
圖片
視訊記憶體利用率:
圖片
rocm-smi 輸出:

圖片

訓練設定

將 LLaMA 3.1 從 PyTorch 移植到 JAX
圖片
此前,Nikhil Sonti 分享過如何將 LLaMA 3.1 從 PyTorch 移植到 JAX。他指出,目前 90% 的大型語言模型(LLM)都執行在 NVIDIA GPU 上,但實際上還有一些同樣強大且價效比更高的替代方案。例如,在 Google TPU 上訓練和部署 Llama 3.1 的成本比 NVIDIA GPU 低約 30%。

然而,支援非 NVIDIA 硬體的開發工具較為匱乏。Sonti 最初嘗試使用 PyTorch XLA 在 TPU 上訓練 Llama 3.1,但過程並不順利。XLA 與 PyTorch 的整合不夠完善,缺少一些關鍵的庫(如 bitsandbytes 無法正常執行),同時還遇到了一些難以解決的 HuggingFace 錯誤。

為此,他決定調整策略,將 Llama 3.1 從 PyTorch 移植到 JAX,成功解決了這些問題。Sonti 還錄製了詳細的教程影片,並開源了所有程式碼:

圖片

  • 方法演示:https://dub.sh/felafax-demo
  • 程式碼倉庫:https://github.com/felafax/felafax

載入模型,並把模型引數分片
處理像 LLaMA 405B 這樣的超大模型,需要在多個裝置之間高效地進行引數分片。以下是如何透過 JAX 實現這一點的。

在 JAX 中進行引數分片

為了將巨大的 LLaMA 405B 模型高效地分佈到 8 張 AMD GPU 上,需要使用 JAX 的裝置網格(device mesh)功能。

部署程式碼:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69

JAX 的裝置網格可以幫助我們把可用的裝置組織成一個網格,讓我們可以指定如何把模型的引數和計算分配到不同的 GPU 上。

在本文的設定中,需要建立一個形狀為(1, 8, 1)的網格,並將軸分別命名為資料並行(dp)、全分片資料並行(fsdp)和模型並行(mp)。然後,為模型的每個張量定義特定的分片規則,指定這些維度如何沿著這些網格軸進行分片。
DEVICES = jax.devices () 
DEVICE_COUNT = len (DEVICES) 
DEVICE_MESH = mesh_utils.create_device_mesh ((1, 8, 1)) 
MESH = Mesh (devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))
視覺化分片

可以使用以下程式碼來視覺化分片結果,從而方便地驗證分片規則是否按預期應用。
jax.debug.visualize_array_sharding 
分片規則

模型不同元件的分片規則如下所示:

  • 引數如何分片:

引數要在 8 個 GPU 之間分配。例如,LM head(lm_head/kernel)張量有兩個軸,按照 PS ("fsdp", "mp") 進行分片。在本例中是 8 和 1,因此可以看到該張量在第一個軸上沿著 8 個 GPU 被拆分。

  • Non-Replicated 引數:

沒有任何分片規範的引數會在所有裝置上進行復制。例如,層歸一化(attention_norm/kernel 和 ffn_norm/kernel)沒有設定分片規範,是 PS (None)。

應用分片函式
在載入模型時,使用以下分片函式逐步對模型權重進行分片:
def make_shard_and_gather_fns (partition_specs):
    def make_shard_fn (partition_spec):
        out_sharding = NamedSharding (mesh, partition_spec)
        def shard_fn (tensor):
            return jax.device_put (tensor, out_sharding).block_until_ready ()
        return shard_fn

    shard_fns = jax.tree_util.tree_map (make_shard_fn, partition_specs)
    return shard_fns

# Create shard functions based on partitioning rules
shard_fns = make_shard_and_gather_fns (partitioning_rules)

這使得我們能夠將每個引數放置在指定的裝置上,並按照設定的分片進行處理。

分片訓練 Batch

最初,訓練 Batch 是正常建立的,但在輸入模型之前,需要按照下面的程式碼在 GPU 上進行分片:
train_batch = jax.device_put ( train_batch,
NamedSharding (self.mesh, PS ("dp", "fsdp")))

在這裡,我們指定訓練 Batch 應該在 "dp" 和 "fsdp" 軸上進行分片,在本例中分別對應於被分成 1 和 8 份,如果把結果視覺化出來,如下所示:

分片前:
圖片
在呼叫 jax.device_put 之後:
圖片
加入 LoRA

LoRA 透過將權重更新分解為低秩矩陣,減少了可訓練引數的數量,這對於微調大型模型特別有效。以下是在 AMD GPU 上微調 Llama 3.1-405 的 LoRA 的要點:

  • 將 LoRA 引數(lora_a 和 lora_b)與主模型引數分開。
  • 使用 jax.lax.stop_gradient (kernel) 來防止對主模型權重的更新。
  • 使用 lax.dot_general 進行快速、精確控制的矩陣運算。
  • LoRA 輸出在新增到主輸出之前會被縮放為 (self.lora_alpha/self.lora_rank)。

LoRADense 層

在此設定一個自定義的 LoRADense 層,該層整合了 LoRA 引數:
class LoRADense (nn.Module):
    features: int
    lora_rank: int = 8
    lora_alpha: float = 16.0
@nn.compact
def __call__(self, inputs: Any) -> Any:
# Original kernel parameter (frozen)
        kernel = self.param ('kernel', ...)
        y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)
# LoRA parameters (trainable)
        lora_a = self.variable ('lora_params', 'lora_a', ..., ...)
        lora_b = self.variable ('lora_params', 'lora_b', ..., ...)
# Compute LoRA output
        lora_output = lax.dot_general (inputs, lora_a.value, ...)
        lora_output = lax.dot_general (lora_output, lora_b.value, ...)
# Combine original output with LoRA modifications
        y += (self.lora_alpha/self.lora_rank) * lora_output

        return y.astype (self.dtype)

分片 LoRA 引數

為了高效地在裝置之間分配 LoRA 引數,我們也透過 JAX 設定了分片規則,這確保了 LoRA 引數與主模型引數的分片一致,最佳化了記憶體使用和計算效率。
LoRA A matrices (lora_a)

LoRA A 矩陣(lora_a)

  • 分片規則:PS ("fsdp", "mp")
  • 視覺化結果:如下圖所示,lora_a 引數被分片為 (8, 1),這意味著第一個軸在 8 個裝置上進行分片("fsdp" 軸),而第二個軸未進行分片。
圖片
LoRA B 矩陣(lora_b)

  • 分片規則:PS ("mp", "fsdp")
  • 視覺化結果:如下圖所示,lora_b 引數被分片為 (1, 8),這意味著第二個軸在 8 個裝置上進行分片(fsdp 軸),而第一個軸未進行分片。
圖片
這種分片策略最佳化了引數的分配,減少了通訊開銷,並在訓練過程中增強了並行性。它確保每個裝置僅持有一部分 LoRA 引數,使得大模型如 LLaMA 405B 的高效擴充套件成為可能。

僅更新 LoRA 引數

為了最佳化訓練,在微調 LLaMA 405B 模型,只計算 LoRA 引數的梯度,保持主模型引數不變。這個方法減少了記憶體使用,並加速了訓練,因為只更新較少的引數。可以移步 GitHub 倉庫,檢視實現細節。

在訓練過程中,每一步都涉及將一批輸入資料透過模型進行處理。由於只有 LoRA 引數是可訓練的,因此模型的預測和計算的損失僅依賴於這些引數,然後對 LoRA 引數進行反向傳播。只更新這些引數簡化了訓練過程,使得在多個 GPU 上高效微調像 LLaMA 405B 這樣的大型模型成為可能。

更多研究細節,請參考原部落格。

相關文章