構建RAG應用-day03: Chroma入門 本地embedding 智譜embedding

passion2021發表於2024-04-23

Chroma入門

使用chroma構建向量資料庫。使用了兩種embedding模型,可供自己選擇。
本地embedding:SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
封裝智譜embedding使得其可以在langchain中使用。

import os
from dotenv import load_dotenv, find_dotenv
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.document_loaders.markdown import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings
from embed import ZhipuAIEmbeddings

_ = load_dotenv(find_dotenv())


# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
# os.environ["HTTP_PROXY"] = 'http://127.0.0.1:7890'

# 獲取folder_path下所有檔案路徑,儲存在file_paths裡
def generate_path(folder_path: str = '../data_base/knowledge_db') -> list:
    file_paths = []
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            file_path = os.path.join(root, file)
            file_paths.append(file_path)
    return file_paths


def generate_loaders(file_paths: list) -> list:
    loaders = []
    for file_path in file_paths:
        file_type = file_path.split('.')[-1]
        if file_type == 'pdf':
            loaders.append(PyMuPDFLoader(file_path))
        elif file_type == 'md':
            loaders.append(UnstructuredMarkdownLoader(file_path))
    return loaders


def exec_load(loaders: list) -> list:
    texts = []
    for loader in loaders:
        texts.extend(loader.load())
    return texts


def slice_docs(texts):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50)
    return text_splitter.split_documents(texts)


class VectorDB:
    # embedding = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    embedding = ZhipuAIEmbeddings()
    persist_directory = '../data_base/vector_db/chroma'
    slice = 20

    def __init__(self, sliced_docs: list = None):
        assert sliced_docs is not None
        self.vectordb = Chroma.from_documents(
            documents=sliced_docs[:self.slice],  # 為了速度,只選擇前 20 個切分的 doc 進行生成;使用千帆時因QPS限制,建議選擇前 5 個doc
            embedding=self.embedding,
            persist_directory=self.persist_directory  # 允許我們將persist_directory目錄儲存到磁碟上
        )

    def persist(self):
        self.vectordb.persist()
        print(f"向量庫中儲存的數量:{self.vectordb._collection.count()}")

    def sim_search(self, query, k=3):
        sim_docs = self.vectordb.similarity_search(query, k=k)
        for i, sim_doc in enumerate(sim_docs, start=1):
            print(f"檢索到的第{i}個內容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
        return sim_docs

    def mmr_search(self, query, k=3):
        mmr_docs = self.vectordb.max_marginal_relevance_search(query, k=k)
        for i, sim_doc in enumerate(mmr_docs, start=1):
            print(f"MMR 檢索到的第{i}個內容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
        return mmr_docs


if __name__ == '__main__':
    # 讀取目錄下的所有檔案路徑
    file_paths = generate_path()
    # 根據檔案生成載入器
    loaders = generate_loaders(file_paths)
    # 執行文件載入
    texts = exec_load(loaders)
    # 切分文件
    sliced_docs = slice_docs(texts)
    # 構建向量資料庫
    vdb = VectorDB(sliced_docs)
    # 向量持久化儲存
    vdb.persist()
    # 定義問題
    question = "什麼是大語言模型"
    # 相似度檢索
    vdb.sim_search(question)
    # 最大邊際相關性(MMR) 檢索
    vdb.mmr_search(question)

langchain embedding封裝

需要一個智譜APIkey,官網註冊並且實名認證即可:智譜AI開放平臺 (bigmodel.cn)

from __future__ import annotations
import logging
from typing import Dict, List, Any
from dotenv import load_dotenv, find_dotenv
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator

logger = logging.getLogger(__name__)
_ = load_dotenv(find_dotenv())


# 在 Python 中,root_validator 是 Pydantic 模組中一個用於自定義資料校驗的裝飾器函式。root_validator 用於在校驗整個資料模型之前對整個資料模型進行自定義校驗,以確保所有的資料都符合所期望的資料結構。
class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """`Zhipuai Embeddings` embedding models."""

    client: Any
    """`zhipuai.ZhipuAI"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """
        例項化ZhipuAI為values["client"]

        Args:

            values (Dict): 包含配置資訊的字典,必須包含 client 的欄位.
        Returns:

            values (Dict): 包含配置資訊的字典。如果環境中有zhipuai庫,則將返回例項化的ZhipuAI類;否則將報錯 'ModuleNotFoundError: No module named 'zhipuai''.
        """
        from zhipuai import ZhipuAI
        values["client"] = ZhipuAI()
        return values

    def embed_query(self, text: str) -> List[float]:
        """
        生成輸入文字的 embedding.

        Args:
            texts (str): 要生成 embedding 的文字.

        Return:
            embeddings (List[float]): 輸入文字的 embedding,一個浮點數值列表.
        """
        embeddings = self.client.embeddings.create(
            model="embedding-2",
            input=text
        )
        return embeddings.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        生成輸入文字列表的 embedding.
        Args:
            texts (List[str]): 要生成 embedding 的文字列表.

        Returns:
            List[List[float]]: 輸入列表中每個文件的 embedding 列表。每個 embedding 都表示為一個浮點值列表。
        """
        return [self.embed_query(text) for text in texts]

相關文章