嘗試上手大模型的有監督微調

倒地發表於2024-04-27

總覽

這個文章留下微調 Gemma-2b-it 模型的記錄。以很簡單的、只有一句話的微調為例。

本文最終的目標:問模型 “微調測試”,模型回答 “我學會了”。

準備

載入模型和分詞器

tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2b-it",
    cache_dir="/mnt/g/models/cache",
    token=token,
)

model = AutoModelForCausalLM.from_pretrained(
    "/mnt/g/models/gemma-1.1-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

訓練集

對話訓練集格式通常是 json,網上下載的字尾一般為 .jsonl,長這樣(來源):

{"conversation": [{"role": "human", "message": "中國移動營銷行來發展報告alink"}, {"role": "assistant", "message": "王大姐,打字細心一點"}], "category": null, "data_source": "weibo"}
{"conversation": [{"role": "human", "message": "小馬也瘋狂------地位之爭。"}, {"role": "assistant", "message": "於老師不給勸勸架麼告訴他們再掙也不是老大"}], "category": null, "data_source": "weibo"}
{"conversation": [{"role": "human", "message": "那些年,我們一起偷看過的電視。「暴走漫畫」"}, {"role": "assistant", "message": "真不愧是這麼走出來的爹·······"}], "category": null, "data_source": "weibo"}
···

使用 HuggingFace 的 dataset 庫可載入這個檔案。

from datasets import load_dataset

dataset_fromfile = load_dataset(
    "json",
    data_files=r"./test_input_dataset.jsonl",
    split="train",
)

本文只用一個對話的微調進行示例,所以用以下方法準備訓練集 chat

chat = {
    "conversation": [
        [
            {
                "role": "user",
                "content": "微調測試",
            },
            {
                "role": "assistant",
                "content": "我學會了",
            },
        ],
    ]
}
chat = Dataset.from_dict(chat)

TRL 庫

藉助 HuggingFace 的 TRL 庫進行微調。

pip install trl

接下來介紹 TRL 的兩個類,以及一個特殊的函式 formatting_func

SFTTrainer

SFT 是 “有監督微調” 的縮寫(Supervised Finetuning)。

SFTTrainer 繼承於 transformers.Trainer。藉助 SFTTrainer,可以封裝一個專用於語言模型有監督微調的類。

DataCollatorForCompletionOnlyLM

藉助 DataCollatorForCompletionOnlyLM,可以僅對需要生成的 prompt 訓練。即,只對模型生成的 token 部分計算 loss。

其他細節不必深究,只需要知道 SFTTrainer 需要一個 data_collator 物件,將語料轉換成適合訓練的形式。

response_template = "<start_of_turn>model\n"

collator = DataCollatorForCompletionOnlyLM(
    tokenizer=tokenizer,
    response_template=response_template,
)

可見,例項化這個 collator 需要傳入 tokenizerresponse_template

在 Gemma 中,模型的回答都接在 "<start_of_turn>model\n" 之後,所以傳入這個 response_template 告訴 collator 從這裡開始標記需要訓練的部分。

formatting_func

語料需要先轉換成某種字串,再轉換成 token,才能輸入到模型。

為了將訓練語料正確處理成符合預訓練模型規則的字串,SFTTrainer 需要傳入一個處理函式。

def formatting_prompts_func(example):
    output_texts = []
    for c in example["conversation"]:
        text = tokenizer.apply_chat_template(c, tokenize=False) + tokenizer.eos_token
        output_texts.append(text)
    return output_texts

這裡取了巧,藉助 tokenizer 自帶的 chat_template 轉換。

TrainingArguments

需要向 SFTTrainer 傳入最佳化器、學習率等引數。

不必多言,看示例程式碼。更多可選引數請查閱 HuggingFace 文件。

from transformers import TrainingArguments

args = TrainingArguments(
        per_device_train_batch_size=8,
        num_train_epochs=30,
        learning_rate=2e-5,
        optim="adamw_8bit",
        bf16=True,
        output_dir="/mnt/z/model_test",
        report_to=["tensorboard"],
        logging_steps=1,
)

開始訓練

做好一切準備後,就能例項化 SFTrainer 開始訓練了。

trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=chat,
    max_seq_length=1024,
    args=args,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    dataset_kwargs={"add_special_tokens": False},  # 特殊 token 已經在 formatting_func 加過了
)
trainer.train()

LoRA

藉助 peft 庫,只需要封裝一遍 model 就能應用 LoRA。

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=512,
    lora_alpha=512,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ],
)

model = get_peft_model(model, lora_config)
# model.print_trainable_parameters()

接下來向 SFTTrainer 傳入這個 model 就行。

測試

我用這段程式碼測試訓練效果:

chat = [
    {
        "role": "user",
        "content": "微調測試",
    },
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_length=100)
print(tokenizer.decode(outputs[0]))

可以看到效果很明顯。

<bos><start_of_turn>user
微調測試<end_of_turn>
<start_of_turn>model
我學會了<end_of_turn>
<eos>

相關文章