Chroma入門
使用chroma構建向量資料庫。使用了兩種embedding模型,可供自己選擇。
本地embedding:SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
封裝智譜embedding使得其可以在langchain中使用。
import os
from dotenv import load_dotenv, find_dotenv
from langchain.document_loaders.pdf import PyMuPDFLoader
from langchain.document_loaders.markdown import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain.embeddings import SentenceTransformerEmbeddings
from embed import ZhipuAIEmbeddings
_ = load_dotenv(find_dotenv())
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
# os.environ["HTTP_PROXY"] = 'http://127.0.0.1:7890'
# 獲取folder_path下所有檔案路徑,儲存在file_paths裡
def generate_path(folder_path: str = '../data_base/knowledge_db') -> list:
file_paths = []
for root, dirs, files in os.walk(folder_path):
for file in files:
file_path = os.path.join(root, file)
file_paths.append(file_path)
return file_paths
def generate_loaders(file_paths: list) -> list:
loaders = []
for file_path in file_paths:
file_type = file_path.split('.')[-1]
if file_type == 'pdf':
loaders.append(PyMuPDFLoader(file_path))
elif file_type == 'md':
loaders.append(UnstructuredMarkdownLoader(file_path))
return loaders
def exec_load(loaders: list) -> list:
texts = []
for loader in loaders:
texts.extend(loader.load())
return texts
def slice_docs(texts):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500, chunk_overlap=50)
return text_splitter.split_documents(texts)
class VectorDB:
# embedding = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
embedding = ZhipuAIEmbeddings()
persist_directory = '../data_base/vector_db/chroma'
slice = 20
def __init__(self, sliced_docs: list = None):
assert sliced_docs is not None
self.vectordb = Chroma.from_documents(
documents=sliced_docs[:self.slice], # 為了速度,只選擇前 20 個切分的 doc 進行生成;使用千帆時因QPS限制,建議選擇前 5 個doc
embedding=self.embedding,
persist_directory=self.persist_directory # 允許我們將persist_directory目錄儲存到磁碟上
)
def persist(self):
self.vectordb.persist()
print(f"向量庫中儲存的數量:{self.vectordb._collection.count()}")
def sim_search(self, query, k=3):
sim_docs = self.vectordb.similarity_search(query, k=k)
for i, sim_doc in enumerate(sim_docs, start=1):
print(f"檢索到的第{i}個內容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
return sim_docs
def mmr_search(self, query, k=3):
mmr_docs = self.vectordb.max_marginal_relevance_search(query, k=k)
for i, sim_doc in enumerate(mmr_docs, start=1):
print(f"MMR 檢索到的第{i}個內容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
return mmr_docs
if __name__ == '__main__':
# 讀取目錄下的所有檔案路徑
file_paths = generate_path()
# 根據檔案生成載入器
loaders = generate_loaders(file_paths)
# 執行文件載入
texts = exec_load(loaders)
# 切分文件
sliced_docs = slice_docs(texts)
# 構建向量資料庫
vdb = VectorDB(sliced_docs)
# 向量持久化儲存
vdb.persist()
# 定義問題
question = "什麼是大語言模型"
# 相似度檢索
vdb.sim_search(question)
# 最大邊際相關性(MMR) 檢索
vdb.mmr_search(question)
langchain embedding封裝
需要一個智譜APIkey,官網註冊並且實名認證即可:智譜AI開放平臺 (bigmodel.cn)
from __future__ import annotations
import logging
from typing import Dict, List, Any
from dotenv import load_dotenv, find_dotenv
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
logger = logging.getLogger(__name__)
_ = load_dotenv(find_dotenv())
# 在 Python 中,root_validator 是 Pydantic 模組中一個用於自定義資料校驗的裝飾器函式。root_validator 用於在校驗整個資料模型之前對整個資料模型進行自定義校驗,以確保所有的資料都符合所期望的資料結構。
class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""`Zhipuai Embeddings` embedding models."""
client: Any
"""`zhipuai.ZhipuAI"""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""
例項化ZhipuAI為values["client"]
Args:
values (Dict): 包含配置資訊的字典,必須包含 client 的欄位.
Returns:
values (Dict): 包含配置資訊的字典。如果環境中有zhipuai庫,則將返回例項化的ZhipuAI類;否則將報錯 'ModuleNotFoundError: No module named 'zhipuai''.
"""
from zhipuai import ZhipuAI
values["client"] = ZhipuAI()
return values
def embed_query(self, text: str) -> List[float]:
"""
生成輸入文字的 embedding.
Args:
texts (str): 要生成 embedding 的文字.
Return:
embeddings (List[float]): 輸入文字的 embedding,一個浮點數值列表.
"""
embeddings = self.client.embeddings.create(
model="embedding-2",
input=text
)
return embeddings.data[0].embedding
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
生成輸入文字列表的 embedding.
Args:
texts (List[str]): 要生成 embedding 的文字列表.
Returns:
List[List[float]]: 輸入列表中每個文件的 embedding 列表。每個 embedding 都表示為一個浮點值列表。
"""
return [self.embed_query(text) for text in texts]