langchain chatchat執行機制原始碼解析

郑某發表於2024-03-19

langchain chatchat的簡介就不多說了,大家可以去看github官網介紹,雖然當前版本停止了更新,下個版本還沒有出來,但作為學習還是很好的。

一、關鍵啟動過程:

1、start_main_server 入口

2、run_controller 啟動fastchat controller 埠20001

3、run_openai_api啟動fastchat對外提供的類似openai介面的服務,埠20000

4、run_model_worker 建立fastchat的model_worker,其中又執行了以下過程:

4.1、create_model_worker_app,根據配置檔案,建立並初始化對應的model_workder,初始化過程中,model_worker會透過self.init_heart_beat()將自己註冊到fastchat controller中,以供fastchat管理呼叫。最後create_model_worker_app方法取出model_workder的fastaip物件app,將app返回。

4.2 、uvicorn.run(app, host=host, port=port, log_level=log_level.lower()),啟動模型對應的model_workder服務,這裡的app來自model_workder的app。

二、chat過程

1、app.post("/chat/chat",
tags=["Chat"],
summary="與llm模型對話(透過LLMChain)",
)(chat)
2、本地模型LLM對話
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=callbacks,
)
get_ChatOpenAI:
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=config.get("api_key", "EMPTY"),
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_proxy=config.get("openai_proxy"),
**kwargs
)
在這裡指定了fastchat的openai_api介面地址,這樣就獲得了指定介面地址的langchain ChatOpenAI物件
然後建立LLMChain:
chain = LLMChain(prompt=chat_prompt, llm=model, memory=memory)
後面省略
3、線上模型LLM對話
線上模型的呼叫並沒有直接發起,還是和上面一樣,透過獲取ChatOpenAI物件,來和fastchat進行互動,但是fastchat是不支援自定義呼叫線上模型的,langchain chatchat是怎麼實現的呢?
原來,對應線上模型呼叫,langchain chatchat還是透過類似建立本地模型一樣建立model_worker,但是對model_worker進行了繼承,互動部分進行了重寫,如qwen線上呼叫:
class QwenWorker(ApiModelWorker):
而ApiModelWorker來自BaseModelWorker,BaseModelWorker就是fastchat的worker_model的基類。(本地模型例項化時用的ModelWorker本身也是繼承自BaseModelWorker)
class ApiModelWorker(BaseModelWorker):
    DEFAULT_EMBED_MODEL: str = None # None means not support embedding

    def __init__(
        self,
        model_names: List[str],
        controller_addr: str = None,
        worker_addr: str = None,
        context_len: int = 2048,
        no_register: bool = False,
        **kwargs,
    ):
        kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
        kwargs.setdefault("model_path", "")
        kwargs.setdefault("limit_worker_concurrency", 5)
        super().__init__(model_names=model_names,
                        controller_addr=controller_addr,
                        worker_addr=worker_addr,
                        **kwargs)
        import fastchat.serve.base_model_worker
        import sys
        self.logger = fastchat.serve.base_model_worker.logger
        # 恢復被fastchat覆蓋的標準輸出
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)

        self.context_len = context_len
        self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency)
        self.version = None

        if not no_register and self.controller_addr:
            self.init_heart_beat()


    def count_token(self, params):
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

    def generate_stream_gate(self, params: Dict):
        self.call_ct += 1

        try:
            prompt = params["prompt"]
            if self._is_chat(prompt):
                messages = self.prompt_to_messages(prompt)
                messages = self.validate_messages(messages)
            else: # 使用chat模仿續寫功能,不支援歷史訊息
                messages = [{"role": self.user_role, "content": f"please continue writing from here: {prompt}"}]

            p = ApiChatParams(
                messages=messages,
                temperature=params.get("temperature"),
                top_p=params.get("top_p"),
                max_tokens=params.get("max_new_tokens"),
                version=self.version,
            )
            for resp in self.do_chat(p):
                yield self._jsonify(resp)
        except Exception as e:
            yield self._jsonify({"error_code": 500, "text": f"{self.model_names[0]}請求API時發生錯誤:{e}"})

    def generate_gate(self, params):
        try:
            for x in self.generate_stream_gate(params):
                ...
            return json.loads(x[:-1].decode())
        except Exception as e:
            return {"error_code": 500, "text": str(e)}


    # 需要使用者自定義的方法

    def(self, params: ApiChatParams) -> Dict:
        '''
        執行Chat的方法,預設使用模組裡面的chat函式。
        要求返回形式:{"error_code": int, "text": str}
        '''
        return {"error_code": 500, "text": f"{self.model_names[0]}未實現chat功能"}

    # def do_completion(self, p: ApiCompletionParams) -> Dict:
    #     '''
    #     執行Completion的方法,預設使用模組裡面的completion函式。
    #     要求返回形式:{"error_code": int, "text": str}
    #     '''
    #     return {"error_code": 500, "text": f"{self.model_names[0]}未實現completion功能"}

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        '''
        執行Embeddings的方法,預設使用模組裡面的embed_documents函式。
        要求返回形式:{"code": int, "data": List[List[float]], "msg": str}
        '''
        return {"code": 500, "msg": f"{self.model_names[0]}未實現embeddings功能"}

    def get_embeddings(self, params):
        # fastchat對LLM做Embeddings限制很大,似乎只能使用openai的。
        # 在前端透過OpenAIEmbeddings發起的請求直接出錯,無法請求過來。
        print("get_embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        raise NotImplementedError

    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        '''
        有些API對mesages有特殊格式,可以重寫該函式替換預設的messages。
        之所以跟prompt_to_messages分開,是因為他們應用場景不同、引數不同
        '''
        return messages


    # help methods
    @property
    def user_role(self):
        return self.conv.roles[0]

    @property
    def ai_role(self):
        return self.conv.roles[1]

    def _jsonify(self, data: Dict) -> str:
        '''
        將chat函式返回的結果按照fastchat openai-api-server的格式返回
        '''
        return json.dumps(data, ensure_ascii=False).encode() + b"\0"

    def _is_chat(self, prompt: str) -> bool:
        '''
        檢查prompt是否由chat messages拼接而來
        TODO: 存在誤判的可能,也許從fastchat直接傳入原始messages是更好的做法
        '''
        key = f"{self.conv.sep}{self.user_role}:"
        return key in prompt

    def prompt_to_messages(self, prompt: str) -> List[Dict]:
        '''
        將prompt字串拆分成messages.
        '''
        result = []
        user_role = self.user_role
        ai_role = self.ai_role
        user_start = user_role + ":"
        ai_start = ai_role + ":"
        for msg in prompt.split(self.conv.sep)[1:-1]:
            if msg.startswith(user_start):
                if content := msg[len(user_start):].strip():
                    result.append({"role": user_role, "content": content})
            elif msg.startswith(ai_start):
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result

    @classmethod
    def can_embedding(cls):
        return cls.DEFAULT_EMBED_MODEL is not None

  從程式碼中可以看到ApiModelWorker重寫了generate_stream_gate,並且呼叫了do_chat方法,該方法要求子類去實現實際的chat過程。我們再回到class QwenWorker(ApiModelWorker):

def do_chat(self, params: ApiChatParams) -> Dict:
        import dashscope
        params.load_config(self.model_names[0])
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:params: {params}')

        gen = dashscope.Generation()
        responses = gen.call(
            model=params.version,
            temperature=params.temperature,
            api_key=params.api_key,
            messages=params.messages,
            result_format='message',  # set the result is message format.
            stream=True,
        )

        for resp in responses:
            if resp["status_code"] == 200:
                if choices := resp["output"]["choices"]:
                    yield {
                        "error_code": 0,
                        "text": choices[0]["message"]["content"],
                    }
            else:
                data = {
                    "error_code": resp["status_code"],
                    "text": resp["message"],
                    "error": {
                        "message": resp["message"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"請求千問 API 時發生錯誤:{data}")
                yield data

  至此,qwen線上模型完成了呼叫。

三、總結

不得不說,這種設計還是很精妙的,藉助fastchat,不僅實現了fastchat支援的幾個本地大模型的呼叫,對於線上模型,即使不同的線上模型有不同的api介面定義,但只需要去定義實現一個新的繼承ApiModelWorker的類,就可以遮蔽掉介面之間的差異,透過fastchat對齊介面,統一對外提供類openai api介面服務,這樣在langchain不做修改的情況下,langchain就可以正常呼叫市面上各類介面迥異的線上大模型。

三、後續計劃

1、嘗試langchain chatchat和ollama的對接
2、Agent應用實踐

我建了一個langchain交流群,歡迎加入一起交流學習心得:

相關文章