本文基於 https://github.com/datawhalechina/self-llm/blob/master/GLM-4/02-GLM-4-9B-chat%20langchain%20%E6%8E%A5%E5%85%A5.md 提供的教程。由於使用的是本地部署的大模型,在繼承 LangChain 的 LLM 類時需要重寫幾個函式(如 _call、_llm_type 與 _identifying_params)。
但是在實際測試時出現了以下錯誤:
/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py:1659: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.
warnings.warn(
Traceback (most recent call last):
File "/root/autodl-tmp/glm4LLM.py", line 63, in <module>
print(llm.invoke("你是誰"))
^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 276, in invoke
self.generate_prompt(
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 633, in generate_prompt
return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 803, in generate
output = self._generate_helper(
^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 670, in _generate_helper
raise e
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 657, in _generate_helper
self._generate(
File "/root/miniconda3/lib/python3.12/site-packages/langchain_core/language_models/llms.py", line 1322, in _generate
self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
File "/root/autodl-tmp/glm4LLM.py", line 40, in _call
generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 1758, in generate
result = self._sample(
^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/transformers/generation/utils.py", line 2397, in _sample
outputs = self(
^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 1005, in forward
transformer_outputs = self.transformer(
^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 887, in forward
inputs_embeds = self.embedding(input_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 823, in forward
words_embeddings = self.word_embeddings(input_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/sparse.py", line 163, in forward
return F.embedding(
^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/functional.py", line 2264, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
錯誤原因是 input_ids(輸入資料)與 model(模型)所在的裝置不一致:tokenizer 產生的 input_ids 位於 CPU,而模型位於 cuda:0。
修改為下面的程式碼後即可成功執行。主要改動是在呼叫 generate 之前,先將 model_inputs 中的所有張量(包含 input_ids)移動到模型所在的裝置。
from langchain.llms.base import LLM
from typing import Any, List, Optional, Dict
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
class ChatGLM4_LLM(LLM):
    """Custom LangChain ``LLM`` backed by a locally loaded ChatGLM4 model.

    The tokenizer and model are loaded once in ``__init__`` from a local
    checkpoint path; every ``_call`` wraps the prompt as a single user turn,
    runs ``model.generate`` and returns only the newly generated text.
    """

    # Tokenizer loaded from the checkpoint in ``__init__``.
    tokenizer: AutoTokenizer = None
    # Causal language model loaded from the checkpoint in ``__init__``.
    model: AutoModelForCausalLM = None
    # Keyword arguments forwarded to ``model.generate`` on every call.
    gen_kwargs: dict = None

    def __init__(self, model_name_or_path: str, gen_kwargs: Optional[dict] = None):
        """Load the tokenizer and model from ``model_name_or_path``.

        Args:
            model_name_or_path: Path (or name) handed to ``from_pretrained``
                for both the tokenizer and the model.
            gen_kwargs: Generation keyword arguments passed to
                ``model.generate``. When ``None``, defaults to
                ``{"max_length": 2500, "do_sample": True, "top_k": 1}``.
        """
        super().__init__()
        print("正在從本地載入模型...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        # device_map="auto" lets the loader decide where the weights live,
        # so the model's device is only known after loading.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        ).eval()
        print("完成本地模型的載入")

        if gen_kwargs is None:
            gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
        self.gen_kwargs = gen_kwargs

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a reply to ``prompt`` with the local model.

        Args:
            prompt: User message; sent as a single ``user`` turn.
            stop: Optional stop strings. The reply is cut off at the first
                occurrence of any of them (the stop string itself is not
                included). Empty strings are ignored.
            run_manager: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The decoded text of the newly generated tokens only.
        """
        messages = [{"role": "user", "content": prompt}]
        model_inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,
        )

        # Move every input tensor to the model's device; otherwise generate()
        # fails with "Expected all tensors to be on the same device".
        # NOTE(review): uses the device of the first parameter - confirm this
        # is the input device if the model is sharded over several devices.
        device = next(self.model.parameters()).device
        model_inputs = {key: value.to(device) for key, value in model_inputs.items()}

        generated_ids = self.model.generate(**model_inputs, **self.gen_kwargs)
        # Drop the prompt tokens so that only the continuation is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs["input_ids"], generated_ids)
        ]
        response = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]

        # Honour the caller's stop sequences instead of silently ignoring
        # them: truncate at the earliest match of any non-empty stop string.
        if stop:
            cut_points = [
                position
                for position in (response.find(token) for token in stop if token)
                if position != -1
            ]
            if cut_points:
                response = response[: min(cut_points)]
        return response

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dict identifying this LLM, used for caching and tracing."""
        return {
            "model_name": "glm-4-9b-chat",
            "max_length": self.gen_kwargs.get("max_length"),
            "do_sample": self.gen_kwargs.get("do_sample"),
            "top_k": self.gen_kwargs.get("top_k"),
        }

    @property
    def _llm_type(self) -> str:
        """Return the type tag reported for this LLM."""
        return "glm-4-9b-chat"