關於BGE-M3接入LangChain時遇到的問題與解決方法

YTARO發表於2024-07-05

本文基於https://github.com/datawhalechina/self-llm/blob/master/GLM-4/02-GLM-4-9B-chat%20langchain%20%E6%8E%A5%E5%85%A5.md提供的教程。由於使用本地部署的大模型,在繼承LangChain中的LLM類時需要重寫幾個函式。

但在實際測試時出現了以下錯誤:

/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py:1659: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.
  warnings.warn(
Traceback (most recent call last):
  File "/root/autodl-tmp/glm4LLM.py", line 63, in <module>
    print(llm.invoke("你是誰"))
          ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 276, in invoke
    self.generate_prompt(
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 633, in generate_prompt
    return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 803, in generate
    output = self._generate_helper(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 670, in _generate_helper
    raise e
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 657, in _generate_helper
    self._generate(
  File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 1322, in _generate
    self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
  File "/root/autodl-tmp/glm4LLM.py", line 40, in _call
    generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 1758, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2397, in _sample
    outputs = self(
              ^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 1005, in forward
    transformer_outputs = self.transformer(
                          ^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 887, in forward
    inputs_embeds = self.embedding(input_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 823, in forward
    words_embeddings = self.word_embeddings(input_ids)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/functional.py", line 2264, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

錯誤的原因是 input_ids(輸入資料)與 model(模型)不在同一個裝置上:分詞器產生的 input_ids 位於 CPU,而模型透過 device_map="auto" 被載入到了 CUDA(GPU)上。開頭的 UserWarning 其實已經提示了這一點。

將程式碼修改為下面的版本後即可成功執行。主要改動是:在呼叫 generate() 之前,把 input_ids 等模型輸入移動到模型所在的裝置上。

from langchain.llms.base import LLM
from typing import Any, List, Optional, Dict
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

class ChatGLM4_LLM(LLM):
    """LangChain ``LLM`` backed by a GLM-4 chat model loaded from local files.

    The tokenizer and model are loaded once in ``__init__``. Each call wraps
    the prompt as a single user turn, renders it with the tokenizer's chat
    template, generates with ``gen_kwargs`` and returns only the new text.
    """

    # Declared at class level because the LangChain base class validates its
    # attributes (presumably via pydantic -- confirm for your LangChain
    # version); the real values are assigned in __init__.
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None
    gen_kwargs: dict = None

    def __init__(self, model_name_or_path: str, gen_kwargs: dict = None):
        """Load the tokenizer and model.

        Args:
            model_name_or_path: path (or identifier) accepted by
                ``from_pretrained`` for the GLM-4 chat model.
            gen_kwargs: keyword arguments forwarded to ``model.generate``.
                Defaults to ``{"max_length": 2500, "do_sample": True,
                "top_k": 1}``. The dict is copied, so later changes made by
                the caller do not affect this instance.
        """
        super().__init__()
        print("正在從本地載入模型...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        # device_map="auto" lets transformers place the weights itself (on the
        # GPU when one is available), so inputs must be moved to match later.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        ).eval()
        print("完成本地模型的載入")

        if gen_kwargs is None:
            # top_k=1 keeps only the most likely token, so despite
            # do_sample=True the default decoding is effectively greedy.
            gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
        self.gen_kwargs = dict(gen_kwargs)

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        """Generate a reply to ``prompt``.

        Args:
            prompt: text sent to the model as a single user message.
            stop: optional stop strings; the reply is cut just before the
                earliest occurrence of any of them.
            run_manager: LangChain callback manager (not used here).
            **kwargs: extra invocation arguments (not used here).

        Returns:
            The decoded continuation (prompt tokens removed, special tokens
            skipped), truncated at the first stop string if one matches.
        """
        messages = [{"role": "user", "content": prompt}]
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,
        )

        # The tokenizer produces CPU tensors while the model may live on CUDA;
        # generate() fails with "Expected all tensors to be on the same
        # device" unless the inputs are moved to the model's device first.
        device = next(self.model.parameters()).device
        model_inputs = {key: value.to(device) for key, value in model_inputs.items()}

        generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
        # Drop the prompt tokens from each output row so only the newly
        # generated continuation is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs["input_ids"], generated_ids)
        ]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return self._truncate_at_stop(response, stop)

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Return ``text`` cut before the earliest non-empty stop string."""
        if not stop:
            return text
        cut = len(text)
        for marker in stop:
            if not marker:
                continue  # an empty marker would otherwise match at index 0
            index = text.find(marker)
            if index != -1:
                cut = min(cut, index)
        return text[:cut]

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this LLM for caching and tracing.

        Always contains ``model_name``, ``max_length``, ``do_sample`` and
        ``top_k`` (``None`` when unset). Every other generation kwarg is
        included too, so instances that differ only in e.g. ``temperature``
        are not treated as the same model.
        """
        gen = self.gen_kwargs or {}
        params: Dict[str, Any] = {
            key: gen.get(key) for key in ("max_length", "do_sample", "top_k")
        }
        params.update(gen)
        return {**params, "model_name": "glm-4-9b-chat"}

    @property
    def _llm_type(self) -> str:
        """Return the short type name LangChain uses for this LLM."""
        return "glm-4-9b-chat"

相關文章