隨著 AI 模型的引數量越來越大,對算力的需求也水漲船高。比如最近,Llama-3.1 登上了最強開源大模型的寶座,但超大杯 405B 版本的記憶體就高達 900 多 GB,這對算力構成了更加苛刻的挑戰。如何降低算力的使用成本和使用門檻,已經成為許多公司尋求突破的關鍵。Felafax 就是其中的一家創業公司,致力於簡化 AI 訓練叢集的搭建流程。 Nikhil Sonti 和 Nikhin Sonti 創立了 Felafax,他們的口號是在構建開源 AI 平臺,為下一代 AI 硬體服務,將機器學習的訓練成本降低 30%。與英偉達相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的價效比,按每美元計算,其效能表現更為出色。最近,Felafax 的聯合創始人 Nikhil Sonti 釋出了一篇部落格,詳細分享瞭如何透過 8 張 AMD MI300X GPU 和 JAX 微調 LLaMA 3.1 405B 模型的方法,所有程式碼現已開源。Github 連結:https://github.com/felafax/felafax機器之心對部落格內容進行了不改變原意的編譯、整理,以下是部落格內容:JAX 是一個強大的機器學習庫,結合了類似 NumPy 的 API、自動微分功能以及 Google 的 XLA 編譯器。它在模型並行化方面提供了優秀的 API,因此非常適合像 LLaMA 3.1 405B 這樣的超大模型訓練。在使用 AMD 硬體時,JAX 有幾個明顯的優勢:- 多硬體並行支援:JAX 採用 XLA(加速線性代數)編譯器,將計算編譯為硬體無關的中間表示(HLO),這意味著同樣的 JAX 程式碼無需修改便可高效執行在不同硬體後端,包括 AMD GPU。
- 獨立於底層硬體:XLA 編譯器的最佳化策略是通用的,不針對某個特定的硬體平臺。這使得任何支援 XLA 的硬體裝置(如 CPU、GPU、TPU)都能受益於這些最佳化,獲得更好的效能表現。
- 極高的適應性:從 NVIDIA 轉移到 AMD(或其他硬體)時,JAX 只需做極少的程式碼改動。而相較之下,PyTorch 與英偉達的 CUDA 生態系統緊密耦合,遷移過程相對複雜。
因此,JAX 成為了我們在非英偉達硬體上的最佳選擇。docker pull rocm/jax:latest
# Pull the Docker Image:
docker pull rocm/jax:latest
# Start the Docker Container:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest
# Verify the Installation:
python3 -c 'import jax; print(jax.devices())'
python3 -c 'import jax; print (jax.devices ())'
訓練使用了一個配備了 8 張 AMD MI300x GPU 的 AMD 節點。每張 MI300x 擁有 192GB 的 HBM3 記憶體,效能表現與最新的英偉達 H100 GPU 相比非常出色。 與英偉達 H100 的比較,來源:TensorWave使用 JAX,可以成功地在 AMD GPU 上訓練 LLaMA 405B 模型。我們使用 LoRA 微調,將所有模型權重和 LoRA 引數都設為 bfloat16,LoRA rank 設為 8,LoRA alpha 設為 16:- 模型大小:LLaMA 模型的權重佔用了約 800GB 的視訊記憶體。
- LoRA 權重 + 最佳化器狀態:大約佔用了 400GB 的視訊記憶體。
- 視訊記憶體總使用量:佔總視訊記憶體的 77%,約 1200GB。
- 限制:由於 405B 模型的規模過大,batch 大小和序列長度的空間有限,使用的 batch size 為 16,序列長度為 64。
- JIT 編譯:由於空間限制,無法執行 JIT 編譯版本;它可能需要比急切模式稍多的空間。
- 訓練速度:使用 JAX 急切模式,約為 35 tokens / 秒。
- 擴充套件性:在 8 張 GPU 上,使用 JAX 的擴充套件性接近線性。
由於硬體和視訊記憶體的限制,我們無法執行 JIT 編譯版本的 405B 模型,整個訓練過程是在 JAX 的急切模式下執行的,因此還有很大的進步空間。 下圖中顯示了在一次微調訓練步驟中,8 張 GPU 的視訊記憶體利用率和 rocm-smi 輸出:將 LLaMA 3.1 從 PyTorch 移植到 JAX 此前,Nikhil Sonti 分享過如何將 LLaMA 3.1 從 PyTorch 移植到 JAX。他指出,目前 90% 的大型語言模型(LLM)都執行在 NVIDIA GPU 上,但實際上還有一些同樣強大且價效比更高的替代方案。例如,在 Google TPU 上訓練和部署 Llama 3.1 的成本比 NVIDIA GPU 低約 30%。然而,支援非 NVIDIA 硬體的開發工具較為匱乏。Sonti 最初嘗試使用 PyTorch XLA 在 TPU 上訓練 Llama 3.1,但過程並不順利。XLA 與 PyTorch 的整合不夠完善,缺少一些關鍵的庫(如 bitsandbytes 無法正常執行),同時還遇到了一些難以解決的 HuggingFace 錯誤。為此,他決定調整策略,將 Llama 3.1 從 PyTorch 移植到 JAX,成功解決了這些問題。Sonti 還錄製了詳細的教程影片,並開源了所有程式碼:- 方法演示:https://dub.sh/felafax-demo
- 程式碼倉庫:https://github.com/felafax/felafax
處理像 LLaMA 405B 這樣的超大模型,需要在多個裝置之間高效地進行引數分片。以下是如何透過 JAX 實現這一點的。為了將巨大的 LLaMA 405B 模型高效地分佈到 8 張 AMD GPU 上,需要使用 JAX 的裝置網格(device mesh)功能。部署程式碼:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69JAX 的裝置網格可以幫助我們把可用的裝置組織成一個網格,讓我們可以指定如何把模型的引數和計算分配到不同的 GPU 上。在本文的設定中,需要建立一個形狀為(1, 8, 1)的網格,並將軸分別命名為資料並行(dp)、全分片資料並行(fsdp)和模型並行(mp)。然後,為模型的每個張量定義特定的分片規則,指定這些維度如何沿著這些網格軸進行分片。DEVICES = jax.devices ()
DEVICE_COUNT = len (DEVICES)
DEVICE_MESH = mesh_utils.create_device_mesh ((1, 8, 1))
MESH = Mesh (devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))
可以使用以下程式碼來視覺化分片結果,從而方便地驗證分片規則是否按預期應用。jax.debug.visualize_array_sharding
引數要在 8 個 GPU 之間分配。例如,LM head(lm_head/kernel)張量有兩個軸,按照 PS ("fsdp", "mp") 進行分片。在本例中是 8 和 1,因此可以看到該張量在第一個軸上沿著 8 個 GPU 被拆分。沒有任何分片規範的引數會在所有裝置上進行復制。例如,層歸一化(attention_norm/kernel 和 ffn_norm/kernel)沒有設定分片規範,是 PS (None)。在載入模型時,使用以下分片函式逐步對模型權重進行分片:def make_shard_and_gather_fns (partition_specs):
def make_shard_fn (partition_spec):
out_sharding = NamedSharding (mesh, partition_spec)
def shard_fn (tensor):
return jax.device_put (tensor, out_sharding).block_until_ready ()
return shard_fn
shard_fns = jax.tree_util.tree_map (make_shard_fn, partition_specs)
return shard_fns
# Create shard functions based on partitioning rules
shard_fns = make_shard_and_gather_fns (partitioning_rules)
這使得我們能夠將每個引數放置在指定的裝置上,並按照設定的分片進行處理。最初,訓練 Batch 是正常建立的,但在輸入模型之前,需要按照下面的程式碼在 GPU 上進行分片:train_batch = jax.device_put ( train_batch,
NamedSharding (self.mesh, PS ("dp", "fsdp")))
在這裡,我們指定訓練 Batch 應該在 "dp" 和 "fsdp" 軸上進行分片,在本例中分別對應於被分成 1 和 8 份,如果把結果視覺化出來,如下所示:LoRA 透過將權重更新分解為低秩矩陣,減少了可訓練引數的數量,這對於微調大型模型特別有效。以下是在 AMD GPU 上微調 Llama 3.1-405 的 LoRA 的要點:- 將 LoRA 引數(lora_a 和 lora_b)與主模型引數分開。
- 使用 jax.lax.stop_gradient (kernel) 來防止對主模型權重的更新。
- 使用 lax.dot_general 進行快速、精確控制的矩陣運算。
- LoRA 輸出在新增到主輸出之前會被縮放為 (self.lora_alpha/self.lora_rank)。
在此設定一個自定義的 LoRADense 層,該層整合了 LoRA 引數:class LoRADense (nn.Module):
features: int
lora_rank: int = 8
lora_alpha: float = 16.0
@nn.compac
t
def __call__(self, inputs: Any) -> Any:
# Original kernel parameter (frozen)
kernel = self.param ('kernel', ...)
y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)
# LoRA parameters (trainable)
lora_a = self.variable ('lora_params', 'lora_a', ..., ...)
lora_b = self.variable ('lora_params', 'lora_b', ..., ...)
# Compute LoRA output
lora_output = lax.dot_general (inputs, lora_a.value, ...)
lora_output = lax.dot_general (lora_output, lora_b.value, ...)
# Combine original output with LoRA modifications
y += (self.lora_alpha/self.lora_rank) * lora_output
return y.astype (self.dtype)
為了高效地在裝置之間分配 LoRA 引數,我們也透過 JAX 設定了分片規則,這確保了 LoRA 引數與主模型引數的分片一致,最佳化了記憶體使用和計算效率。- 視覺化結果:如下圖所示,lora_a 引數被分片為 (8, 1),這意味著第一個軸在 8 個裝置上進行分片("fsdp" 軸),而第二個軸未進行分片。
- 視覺化結果:如下圖所示,lora_b 引數被分片為 (1, 8),這意味著第二個軸在 8 個裝置上進行分片(fsdp 軸),而第一個軸未進行分片。
這種分片策略最佳化了引數的分配,減少了通訊開銷,並在訓練過程中增強了並行性。它確保每個裝置僅持有一部分 LoRA 引數,使得大模型如 LLaMA 405B 的高效擴充套件成為可能。為了最佳化訓練,在微調 LLaMA 405B 模型,只計算 LoRA 引數的梯度,保持主模型引數不變。這個方法減少了記憶體使用,並加速了訓練,因為只更新較少的引數。可以移步 GitHub 倉庫,檢視實現細節。在訓練過程中,每一步都涉及將一批輸入資料透過模型進行處理。由於只有 LoRA 引數是可訓練的,因此模型的預測和計算的損失僅依賴於這些引數,然後對 LoRA 引數進行反向傳播。只更新這些引數簡化了訓練過程,使得在多個 GPU 上高效微調像 LLaMA 405B 這樣的大型模型成為可能。