DeepSeek用的GRPO佔用大量記憶體?有人給出了些破解方法

机器之心發表於2025-02-07

RTX 3080 移動版能訓練哪種大模型?本文為那些 GPU 資源有限時使用 GRPO 訓練的開發者提供了寶貴的指導。


自 DeepSeek-R1 釋出以來,群組相對策略最佳化(GRPO)因其有效性和易於訓練而成為大型語言模型強化學習的熱門話題。R1 論文展示瞭如何使用 GRPO 從遵循 LLM(DeepSeek-v3)的基本指令轉變為推理模型(DeepSeek-R1)。

GRPO 是一種線上學習演算法(online learning algorithm),它透過使用訓練過程中由訓練模型自身生成的資料來進行迭代改進。GRPO 的目標是最大化生成補全(completions)的優勢函式(advantage),同時確保模型保持在參考策略(reference policy)附近。
圖片
本文的目的是幫你節省一些時間,讓你根據硬體預算選擇合適的模型大小。在開始微調時,你必須做出的重要決定是選擇模型大小,以及你是執行完全微調還是引數高效微調(PEFT)。

文章作者來自 AI 公司 Oxen.ai 的 CEO Greg Schoeninger。
圖片
原文連結:https://www.oxen.ai/blog/grpo-vram-requirements-for-the-gpu-poor

作者表示,他發現 trl 庫中已經有一個易於使用的 GRPO 實現,便立刻開始了訓練,使用的硬體是配備了 16GB 視訊記憶體的 Nvidia GeForce RTX 3080 的小型膝上型電腦。正如大家可能遇到的問題,作者發現示例程式碼中的引數設定導致了一個巨大的視訊記憶體不足(OOM,out of memory )錯誤。

  1. torch.OutOfMemoryError: CUDA out of memory.

  2. Tried to allocate 1.90 GiB. GPU 0 has a total capacity of 15.73 GiB of which 1.28 GiB is free.

  3. Including non-PyTorch memory, this process has 14.43 GiB memory in use. Of the allocated memory 11.82 GiB is allocated by PyTorch, and 2.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


實際使用情況

作者表示,他們進行了一系列實驗,以確定訓練各種大小的模型所需的視訊記憶體(VRAM)要求。引數數量從 5 億到 140 億不等,他們比較了權重的完全微調與引數高效微調(使用 LoRA),所有訓練執行都在英偉達 H100 上完成,因此這裡的 OOM 意味著 >80GB 的 VRAM。
圖片
在表格中,你可以找到 GSM8K 資料集上訓練的前 100 步中的峰值記憶體使用情況。用於實驗的模型是:
圖片
所有實驗均使用 Shadeform 的 GPU 市場完成,因此每次實驗只需要花費幾美元 H100。

實驗結果表明,記憶體需求隨著模型大小和訓練方式的不同而顯著變化。例如,全引數微調比 PEFT 需要更多的記憶體。

為什麼 GRPO 對記憶體需求較高

這要從 GRPO 的原理說起,這是它的流程圖。
圖片
GRPO 對記憶體需求較高的原因在於,其內部涉及多個模型,並且在訓練資料中每個查詢會產生多個輸出。上圖中的策略模型、參考模型和獎勵模型各自都是一個需要進行推理的 LLM。(儘管從技術上講,獎勵模型可能不需要引數化,可以只是一個 Python 函式或正規表示式,但不影響 GRPO 對記憶體的高需求。)

為什麼 8-Bit 最佳化和梯度檢查點有助於減少記憶體佔用?

通常來講,訓練一個大型語言模型需要在記憶體中儲存三種主要型別的資訊:模型引數、模型學習所需的梯度、最佳化器的跟蹤資料。

對上述內容我們可以這樣理解:如果模型的引數佔用了 X 的空間,那麼梯度也會佔用大約相同的空間。然後,像 AdamW 這樣的最佳化器需要更多的空間,因為它們就像一個記錄員,跟蹤最近的更新歷史,以便更好地決定未來的最佳化。

為了減輕這種記憶體負擔,通常採用兩種技術:

  • 首先,可以使用像 AdamW 這樣的 8-bit 最佳化器版本,它們能更高效地儲存跟蹤資料,同時仍保持良好的效能 —— 類似於壓縮照片可以節省空間,同時保留大部分影像質量;
  • 其次,使用梯度檢查點技術,這就像在訓練過程中拍攝快照,而不是記錄所有內容。雖然這會使訓練速度減慢約 20-30%,但它顯著減少了記憶體使用。

結合這些技術,即使對 GPU 資源有限的人來說,也能夠訓練更大的模型。

程式碼示例

像 trl 這樣的庫已經開始支援 GRPO,使得微調由 transformers 構成的 LLM 變得非常簡單。程式碼也非常簡潔,只需將訓練器替換為 GRPOTrainer 並定義一些獎勵即可。GRPO 的最小程式碼量大約只有 99 行,如果你使用的是像 meta-llama/Llama-3.2-1B-Instruct 這樣的小型模型和像 openai/GSM8K 這樣的資料集,可以非常快速地啟動。

trl 專案地址:https://github.com/huggingface/trl?ref=ghost.oxen.ai

  1. import torch

  2. from datasets import load_dataset, Dataset

  3. from transformers import AutoTokenizer, AutoModelForCausalLM

  4. from trl import GRPOConfig, GRPOTrainer

  5. import re

  6. SYSTEM_PROMPT = """

  7. Respond in the following format:

  8. <reasoning>

  9. ...

  10. </reasoning>

  11. <answer>

  12. ...

  13. </answer>

  14. """

  15. def extract_hash_answer(text: str) -> str | None:

  16. if "####" not in text:

  17. return None

  18. return text.split("####")[1].strip()

  19. def get_gsm8k_questions(split = "train") -> Dataset:

  20. data = load_dataset('openai/gsm8k', 'main')[split]

  21. data = data.map(lambda x: {

  22. 'prompt': [

  23. {'role': 'system', 'content': SYSTEM_PROMPT},

  24. {'role': 'user', 'content': x['question']}

  25. ],

  26. 'answer': extract_hash_answer(x['answer'])

  27. })

  28. return data

  29. def extract_xml_answer(text: str) -> str:

  30. answer = text.split("<answer>")[-1]

  31. answer = answer.split("</answer>")[0]

  32. return answer.strip()

  33. def format_reward_func(completions, **kwargs) -> list[float]:

  34. """Reward function that checks if the completion has a specific format."""

  35. pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"

  36. responses = [completion[0]["content"] for completion in completions]

  37. matches = [re.match(pattern, r) for r in responses]

  38. return [0.5 if match else 0.0 for match in matches]

  39. def accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:

  40. """Reward function that extracts the answer from the xml tags and compares it to the correct answer."""

  41. responses = [completion[0]['content'] for completion in completions]

  42. extracted_responses = [extract_xml_answer(r) for r in responses]

  43. return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

  44. def main():

  45. dataset = get_gsm8k_questions()

  46. model_name = "meta-llama/Llama-3.2-1B-Instruct"

  47. model = AutoModelForCausalLM.from_pretrained(

  48. model_name,

  49. torch_dtype=torch.bfloat16,

  50. attn_implementation="flash_attention_2",

  51. device_map=None

  52. ).to("cuda")

  53. tokenizer = AutoTokenizer.from_pretrained(model_name)

  54. tokenizer.pad_token = tokenizer.eos_token

  55. training_args = GRPOConfig(

  56. output_dir="output",

  57. learning_rate=5e-6,

  58. adam_beta1=0.9,

  59. adam_beta2=0.99,

  60. weight_decay=0.1,

  61. warmup_ratio=0.1,

  62. lr_scheduler_type='cosine',

  63. logging_steps=1,

  64. bf16=True,

  65. per_device_train_batch_size=1,

  66. gradient_accumulation_steps=4,

  67. num_generations=4,

  68. max_prompt_length=256,

  69. max_completion_length=786,

  70. num_train_epochs=1,

  71. save_steps=100,

  72. save_total_limit=1,

  73. max_grad_norm=0.1,

  74. log_on_each_node=False,

  75. )

  76. trainer = GRPOTrainer(

  77. model=model,

  78. processing_class=tokenizer,

  79. reward_funcs=[

  80. format_reward_func,

  81. accuracy_reward_func

  82. ],

  83. args=training_args,

  84. train_dataset=dataset,

  85. )

  86. trainer.train()

  87. if __name__ == "__main__":

  88. main()


Num Generations 有什麼用

Num Generations 是一個超引數,它決定了我們將在訓練資料中對每個查詢取樣多少個補全。然而,這會顯著增加 VRAM 的消耗。
圖片
目前有一個開放的 GitHub 問題,可能會幫助解決記憶體瓶頸問題,可以參考如下連結

地址:https://github.com/huggingface/trl/issues/2709?ref=ghost.oxen.ai

對於 num_completions=8,16,64 (DeepSeekMath 論文使用的 64),作者表示,不用再次計算上述所有值,而是使用了 1B 引數模型進行了測試,以顯示記憶體增長。不過,作者還是建議大家在記憶體瓶頸得到修復之前使用 num_generations=4,也能獲得不錯的效能。
圖片
影響 VRAM 的一些因素

要對所有影響視訊記憶體(VRAM)使用的因素進行全面的超引數驗證,需要進行大量的實驗。簡單起見,這裡只指出了需要注意的設定,以及實驗中使用的具體數值。

  • batch_size=1,由於 GRPO 為每個查詢生成多個響應,batch size 會迅速失控。
  • gradient_accumulation_steps=4,最佳化器是另一個佔用大量 VRAM 的地方。此引數決定了我們將儲存的梯度以幫助最佳化器進行其「爬山」過程。
  • num_completions=4,DeepSeekMath 論文中使用了 64。這完全超出了有些人的計算預算。
  • max_prompt_length=256,如果你想訓練模型擁有更大上下文的推理能力,將不得不增加 VRAM。GSM8K 的提示相對較小,適合此測試。
  • max_completion_length=786,同樣,由於計算注意力的記憶體有限,推理鏈在這裡受到限制。上下文或生成的 token 越多,需要的記憶體就越大。
  • LoRA target_modules=["q_proj", "k_proj", "o_proj", "up_proj", "down_proj"] 在這方面可以嘗試幾種不同的迭代。target_modules="all-linear" 是一種流行的方式,可以從你的 LoRA 中擠出最多的效能(就準確性而言)。

對 VRAM 使用的粗略估算

如果你正在使用 FP16 精度進行訓練,以下是一些簡單的估算方法,可以幫助你瞭解記憶體主要用在了哪些地方:

  • 模型引數:每個引數佔用 2 位元組。
  • 參考模型引數:每個引數佔用 2 位元組。
  • 梯度:每個引數佔用 2 位元組。
  • 最佳化器狀態:每個引數佔用 8 位元組。
  • 8 位最佳化器:每個引數佔用 4 位元組。
  • PEFT:有助於減少梯度的視訊記憶體佔用。

最後是關於準確率的。作者完成了一個 10 億引數的 Llama 3.2 模型的完整訓練。在應用 GRPO 之前,該模型在保留測試集上達到了約 19% 的準確率,而在經過一個訓練週期後,模型的準確率飆升至約 40.5%。雖然這離 SOTA 水平還差得很遠,但這展示了 GRPO 的強大潛力。

相關文章