構建RAG應用-day04-將LLM 接入 LangChain 構建檢索問答鏈 部署知識庫助手

passion2021發表於2024-04-25

llm接入langchain

示例-llm作為翻譯助手

接入chatgpt到langchain,使用普通寫法和Langchain的表示式語言寫法 (LCEL)。

import os
import openai
from dotenv import load_dotenv, find_dotenv
from langchain_openai import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser


def init_openai():
    openai.base_url = 'https://api.chatanywhere.tech/v1'
    _ = load_dotenv(find_dotenv())
    openai_api_key = os.environ['OPENAI_API_KEY']
    llm = ChatOpenAI(temperature=0.0, base_url=openai.base_url, openai_api_key=openai_api_key, streaming=True)
    return llm


def init_prompt():
    template = "你是一個翻譯助手,可以幫助我將 {input_language} 翻譯成 {output_language}."
    human_template = "{text}"

    chat_prompt = ChatPromptTemplate.from_messages([
        ("system", template),
        ("human", human_template),
    ])

    return chat_prompt


if __name__ == '__main__':
    # # 普通寫法
    # # 初始化llm
    # llm = init_openai()
    # # 初始化prompt
    # chat_prompt = init_prompt()
    # text = "我帶著比身體重的行李,\
    #   遊入尼羅河底,\
    #   經過幾道閃電 看到一堆光圈,\
    #   不確定是不是這裡。\
    #   "
    # messages = chat_prompt.format_messages(input_language="中文", output_language="英文", text=text)
    # # 呼叫llm
    # output = llm.invoke(messages)
    # # 輸出解析
    # output_parser = StrOutputParser()
    # result = output_parser.invoke(output)
    # print(result)

    # Langchain的表示式語言寫法 (LCEL)
    llm = init_openai()
    chat_prompt = init_prompt()
    output_parser = StrOutputParser()
    # 類似於 Unix 管道運算子,它將不同的元件連結在一起,將一個元件的輸出作為下一個元件的輸入。
    chain = chat_prompt | llm | output_parser
    chain.invoke({
        "input_language": "中文",
        "output_language": "英文",
        "text": "我帶著比身體重的行李,\
        遊入尼羅河底,\
        經過幾道閃電 看到一堆光圈,\
        不確定是不是這裡。\
      "})

構建問答鏈

使用langchain+chatgpt+zhipu embedding 構建應用,並且使用langchain記憶功能。

from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import openai
from langchain_openai import ChatOpenAI
from embed import ZhipuAIEmbeddings
import os
from dotenv import load_dotenv, find_dotenv
from langchain.vectorstores.chroma import Chroma

from langchain.memory import ConversationBufferMemory


class VectorDB:
    embedding = ZhipuAIEmbeddings()
    persist_directory = 'data_base/vector_db/chroma'
    slice = 20

    def __init__(self, sliced_docs: list = None):
        if sliced_docs is None:
            self.vectordb = Chroma(embedding_function=self.embedding, persist_directory=self.persist_directory)
        else:
            self.vectordb = Chroma.from_documents(
                documents=sliced_docs[:self.slice],  # 為了速度,只選擇前 20 個切分的 doc 進行生成;使用千帆時因QPS限制,建議選擇前 5 個doc
                embedding=self.embedding,
                persist_directory=self.persist_directory  # 允許我們將persist_directory目錄儲存到磁碟上
            )

    def persist(self):
        self.vectordb.persist()
        print(f"向量庫中儲存的數量:{self.vectordb._collection.count()}")

    def sim_search(self, query, k=3):
        sim_docs = self.vectordb.similarity_search(query, k=k)
        for i, sim_doc in enumerate(sim_docs, start=1):
            print(f"檢索到的第{i}個內容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
        return sim_docs

    def mmr_search(self, query, k=3):
        mmr_docs = self.vectordb.max_marginal_relevance_search(query, k=k)
        for i, sim_doc in enumerate(mmr_docs, start=1):
            print(f"MMR 檢索到的第{i}個內容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
        return mmr_docs


def init_openai():
    openai.base_url = 'https://api.chatanywhere.tech/v1'
    _ = load_dotenv(find_dotenv())
    openai_api_key = os.environ['OPENAI_API_KEY']
    llm = ChatOpenAI(temperature=0.0, base_url=openai.base_url, openai_api_key=openai_api_key, streaming=True)
    return llm


if __name__ == '__main__':
    llm = init_openai()
    vectordb = VectorDB().vectordb
    memory = ConversationBufferMemory(
        memory_key="chat_history",  # 與 prompt 的輸入變數保持一致。
        return_messages=True,  # 將以訊息列表的形式返回聊天記錄,而不是單個字串
        output_key="result"
    )

    template = """使用以下上下文來回答最後的問題。如果你不知道答案,就說你不知道,不要試圖編造答
    案。最多使用三句話。儘量使答案簡明扼要。總是在回答的最後說“謝謝你的提問!”。
    {context}
    問題: {question}
    """

    QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],
                                     template=template)
    qa_chain = RetrievalQA.from_chain_type(llm,
                                           retriever=vectordb.as_retriever(),
                                           return_source_documents=True,
                                           memory=memory,
                                           chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})

    question_1 = "什麼是南瓜書?"
    question_2 = "王陽明是誰?"

    # result = qa_chain({"query": question_1})
    # print("大模型+知識庫後回答 question_1 的結果:")
    # print(result["result"])

    question = "我可以學習到關於提示工程的知識嗎?"
    result = qa_chain({"query": question})
    print(result['result'])

部署知識庫助手

使用streamlit部署知識庫助手。

import openai
import streamlit as st
from langchain_openai import ChatOpenAI
import os
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from embed import ZhipuAIEmbeddings
from langchain.vectorstores.chroma import Chroma
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())  # read local .env file
openai.base_url = 'https://api.chatanywhere.tech/v1'

zhipuai_api_key = os.environ['ZHIPUAI_API_KEY']


def generate_response(input_text, openai_api_key):
    llm = ChatOpenAI(temperature=0.7, openai_api_key=openai_api_key, base_url=openai.base_url)
    output = llm.invoke(input_text)
    output_parser = StrOutputParser()
    output = output_parser.invoke(output)
    # st.info(output)
    return output


def get_vectordb():
    # 定義 Embeddings
    embedding = ZhipuAIEmbeddings()
    # 向量資料庫持久化路徑
    persist_directory = 'data_base/vector_db/chroma'
    # 載入資料庫
    vectordb = Chroma(
        persist_directory=persist_directory,  # 允許我們將persist_directory目錄儲存到磁碟上
        embedding_function=embedding
    )
    return vectordb


# 帶有歷史記錄的問答鏈
def get_chat_qa_chain(question: str, openai_api_key: str):
    vectordb = get_vectordb()
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=openai_api_key, base_url=openai.base_url)
    memory = ConversationBufferMemory(
        memory_key="chat_history",  # 與 prompt 的輸入變數保持一致。
        return_messages=True  # 將以訊息列表的形式返回聊天記錄,而不是單個字串
    )
    retriever = vectordb.as_retriever()
    qa = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        memory=memory
    )
    result = qa({"question": question})
    return result['answer']


# 不帶歷史記錄的問答鏈
def get_qa_chain(question: str, openai_api_key: str):
    vectordb = get_vectordb()
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=openai_api_key, base_url=openai.base_url)
    template = """使用以下上下文來回答最後的問題。如果你不知道答案,就說你不知道,不要試圖編造答
        案。最多使用三句話。儘量使答案簡明扼要。總是在回答的最後說“謝謝你的提問!”。
        {context}
        問題: {question}
        """
    QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],
                                     template=template)
    qa_chain = RetrievalQA.from_chain_type(llm,
                                           retriever=vectordb.as_retriever(),
                                           return_source_documents=True,
                                           chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})
    result = qa_chain({"query": question})
    return result["result"]


# Streamlit 應用程式介面
def main():
    st.title('🦜🔗 動手學大模型應用開發')
    openai_api_key = st.sidebar.text_input('OpenAI API Key', type='password')

    # 新增一個選擇按鈕來選擇不同的模型
    # selected_method = st.sidebar.selectbox("選擇模式", ["qa_chain", "chat_qa_chain", "None"])
    selected_method = st.radio(
        "你想選擇哪種模式進行對話?",
        ["None", "qa_chain", "chat_qa_chain"],
        captions=["不使用檢索問答的普通模式", "不帶歷史記錄的檢索問答模式", "帶歷史記錄的檢索問答模式"])

    # 用於跟蹤對話歷史
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    messages = st.container(height=300)
    if prompt := st.chat_input("Say something"):
        # 將使用者輸入新增到對話歷史中
        st.session_state.messages.append({"role": "user", "text": prompt})

        if selected_method == "None":
            # 呼叫 respond 函式獲取回答
            answer = generate_response(prompt, openai_api_key)
        elif selected_method == "qa_chain":
            answer = get_qa_chain(prompt, openai_api_key)
        elif selected_method == "chat_qa_chain":
            answer = get_chat_qa_chain(prompt, openai_api_key)

        # 檢查回答是否為 None
        if answer is not None:
            # 將LLM的回答新增到對話歷史中
            st.session_state.messages.append({"role": "assistant", "text": answer})

        # 顯示整個對話歷史
        for message in st.session_state.messages:
            if message["role"] == "user":
                messages.chat_message("user").write(message["text"])
            elif message["role"] == "assistant":
                messages.chat_message("assistant").write(message["text"])


if __name__ == "__main__":
    main()

相關文章