ChatGLM-6B本地部署和P-Tuning微調嘗試

yicheng_liu0219發表於2024-05-07

to 2024 / 04 / 22

部署環境

OS: Windows10, WSL2 ( Ubuntu 20.04 )

CPU: Intel(R) Core(TM) i5-12490F

GPU: GeForce RTX 4070Ti

部署過程

部署主要參考$[2]$,其中也遇到了一定的問題,記錄如下:

模型下載

模型需要使用Git LFS工具進行下載,由於之前在Windows環境下已經下載過模型檔案,且檔案較大,直接在系統內進行復制而沒有重複下載(具體可以參考$[3]$)。WindowsPowerShell下載指令:

git clone https://huggingface.co/THUDM/chatglm-6b

需要將如下對應檔案複製到WSL2自己設定的檔案路徑下:

環境配置

使用conda (4.5.11) 建立環境,pip (23.3.1)配置環境,可以嘗試直接在git的專案$[1]$路徑下執行:

pip install -r requirements.txt

最開始下載時存在部分模組(e.g. PyYAML)版本不一致問題,可能是conda最開始初始化時導致的,如果按照所需的環境逐個下載,可以嘗試使用以下指令強行更新版本(但是無法刪除,參考$[4]$):

pip3 install --ignore-installed PyYAML

在之後執行模型時,可能遇到 'Textbox' object has no attribute 'style' 報錯,可能是gradio模組版本過高導致的,可以嘗試單獨指定gradio版本(參考$[5]$):

pip uninstall gradio
pip install gradio==3.50.0

DEMO & API 嘗試

專案本身提供了web和cli兩個demo,但個人在使用web demo載入時會出現問題,考慮到專案有自己單獨的前端,所以該問題未解決,cli demo可以正常執行,需要修改cli_demo.py中的部分內容:

LOCAL_PATH = "/home/lyc/workspace/ChatGLM-6B"
tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH+"/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(LOCAL_PATH+"/chatglm-6b", trust_remote_code=True).half().cuda()

需要注意的是LOCAL_PATH需要是絕對路徑。在視訊記憶體不足時可以進行量化:

# 按需修改,目前只支援 4/8 bit 量化
model = AutoModel.from_pretrained("THUDM/chatglm-6b", 	trust_remote_code=True).quantize(8).half().cuda()

命令列執行結果如下:

基於 P-Tuning 微調 ChatGLM-6B

安裝依賴,且需要確保transformers模組版本為4.27.1,嘗試執行如下程式碼:

pip install rouge_chinese nltk jieba datasets
export WANDB_DISABLED=true

在最開始git的專案中,your_path/ChatGLM-6B/ptuning路徑下提供了P-Tuning的demo,需要修改如下內容:

其中藍框內的cli_demo.py是因為自帶的web_demo我無法執行,簡單修改了最開始目錄下的內容來執行經過微調後的模型的,cli_demo.sh用於啟動cli_demo.py,兩者內容如下:

# cli_demo.py
import os, sys
import platform
import signal

import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)

from arguments import ModelArguments, DataTrainingArguments

import readline

# LOCAL_PATH = "/home/lyc/workspace/ChatGLM-6B"

# tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH+"/chatglm-6b", trust_remote_code=True)
# model = AutoModel.from_pretrained(LOCAL_PATH+"/chatglm-6b", trust_remote_code=True).half().cuda()
# model = model.eval()

model = None
tokenizer = None

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False


def build_prompt(history):
    prompt = "歡迎使用 ChatGLM-6B 模型,輸入內容即可進行對話,clear 清空對話歷史,stop 終止程式"
    for query, response in history:
        prompt += f"\n\n使用者:{query}"
        prompt += f"\n\nChatGLM-6B:{response}"
    return prompt


def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True


def main():
    global model, tokenizer

    parser = HfArgumentParser((
        ModelArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        model_args = parser.parse_args_into_dataclasses()[0]

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True)
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True)

    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection

    if model_args.ptuning_checkpoint is not None:
        print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}")
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)

    if model_args.quantization_bit is not None:
        print(f"Quantized to {model_args.quantization_bit} bit")
        model = model.quantize(model_args.quantization_bit)

    if model_args.pre_seq_len is not None:
        # P-tuning v2
        model = model.half().cuda()
        model.transformer.prefix_encoder.float().cuda()
    
    model = model.eval()

    history = []
    global stop_stream
    print("歡迎使用 ChatGLM-6B 模型,輸入內容即可進行對話,clear 清空對話歷史,stop 終止程式")
    while True:
        query = input("\n使用者:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print("歡迎使用 ChatGLM-6B 模型,輸入內容即可進行對話,clear 清空對話歷史,stop 終止程式")
            continue
        count = 0
        for response, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()

在cli_demo.sh中,model_name_or_path需要改為你最開始下載模型的位置,ptuning_checkpoint需要與train.sh中的內容相對應,不同的訓練模型會儲存在不同地方。

PRE_SEQ_LEN=32

CUDA_VISIBLE_DEVICES=0 python3 cli_demo.py \
    --model_name_or_path /home/lyc/workspace/ChatGLM-6B/chatglm-6b \
    --ptuning_checkpoint output/adgen-chatglm-6b-pt-32-2e-2/checkpoint-500 \
    --pre_seq_len $PRE_SEQ_LEN

橘框中為測試資料和訓練資料,以json格式進行儲存,形如:

[
	{"content": "xxx1", "summary": "yyy1"},
	{"content": "xxx2", "summary": "yyy2"},
	...
	{"content": "xxx3", "summary": "yyy3"}
]

紅框為訓練和測試的指令碼,可以參考$[2]$按需修改對應引數 。

其他問題

部分模組或模型下載可能需要代理,個人使用clash代理,WSL2中需要配置git和conda的代理,git可以參考$[6]$,conda可以在使用者目錄下修改 .condarc檔案,增添內容:

proxy_servers:
  http: http://nameserver:port
  https: https://nameserver:port
ssl_verify: false

其中nameserver可以在路徑 /etc/resolv.conf中檢視,port請參考clash中的設定,預設為7890。

後續(本科專案實訓)

在測試中,使用 5 條資料訓練 500 epoch,損失函式基本收斂,驗證準確率較高,但距離目標任務的實際使用還有一定的距離,面對不同的輸入格式的魯棒性不足,需要設計輸出函式格式並自動生成更多的訓練測試資料。

本地部署算力較為吃緊,可能需要在伺服器上進行微調。模型API需要進一步熟悉,以方便後續的專案開發。

參考資料

[1] ChatGLM-6B: An Open Bilingual Dialogue Language Model | 開源雙語對話語言模型
https://github.com/THUDM/ChatGLM-6B
[2] chatglm的微調有沒有保姆式的教程?? - 樹先生的回答 - 知乎
https://www.zhihu.com/question/595670355/answer/3015099216
[3] 安裝 Git Large File Storage
https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
[4] [已解決] Cannot uninstall ‘PyYAML’.
https://clay-atlas.com/blog/2022/04/08/cannot-uninstall-pyyaml-distutils-installed-project/#google_vignette
[5] chatglm2-b部署報錯問題‘Textbox‘ object has no attribute ‘style‘
https://blog.csdn.net/m0_54393918/article/details/134355019
[6] WSL2 訪問 Clash 網路代理
https://jike.dev/posts/wsl2-access-clash-network-proxy/

相關文章