Personal Knowledge Base Assistant
This article is based on the Datawhale open-source learning project: llm-universe/docs/C6 at main · datawhalechina/llm-universe (github.com)
Obtaining the Data
The llm-universe personal knowledge base assistant uses a selection of classic Datawhale open-source courses and (parts of) videos as its example corpus, specifically:
- Detailed Explanation of Machine Learning Formulas (《機器學習公式詳解》), PDF version
- LLM Tutorial for Developers, Part 1: Prompt Engineering (《面向開發者的 LLM 入門教程 第一部分 Prompt Engineering》), Markdown version
- Introduction to Reinforcement Learning (《強化學習入門指南》), MP4 version
- the READMEs of all open-source projects under the Datawhale organization: https://github.com/datawhalechina
These source documents are placed under the /data_base/knowledge_db directory; you can also put your own files there.
Let's first look at how to fetch the READMEs of all open-source projects in the Datawhale organization. Run the project/database/test_get_all_repo.py file to download them; the code is as follows:
import json
import requests
import os
import base64
import loguru
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# GitHub token; replace with your own (see the steps below)
TOKEN = 'your github token'
# Fetch the list of repositories under an organization
def get_repos(org_name, token, export_dir):
headers = {
'Authorization': f'token {token}',
}
    url = f'https://api.github.com/orgs/{org_name}/repos'
    # GitHub caps per_page at 100, so a single request returns at most the first 100 repositories
    response = requests.get(url, headers=headers, params={'per_page': 100, 'page': 1})
if response.status_code == 200:
repos = response.json()
loguru.logger.info(f'Fetched {len(repos)} repositories for {org_name}.')
        # Use export_dir to determine where to save the list of repository names
        os.makedirs(export_dir, exist_ok=True)
        repositories_path = os.path.join(export_dir, 'repositories.txt')
with open(repositories_path, 'w', encoding='utf-8') as file:
for repo in repos:
file.write(repo['name'] + '\n')
return repos
else:
loguru.logger.error(f"Error fetching repositories: {response.status_code}")
loguru.logger.error(response.text)
return []
# Fetch and save the README file of a single repository
def fetch_repo_readme(org_name, repo_name, token, export_dir):
headers = {
'Authorization': f'token {token}',
}
url = f'https://api.github.com/repos/{org_name}/{repo_name}/readme'
response = requests.get(url, headers=headers)
if response.status_code == 200:
readme_content = response.json()['content']
        # Decode the base64-encoded content
readme_content = base64.b64decode(readme_content).decode('utf-8')
        # Use export_dir to determine where to save the README
        repo_dir = os.path.join(export_dir, repo_name)
        os.makedirs(repo_dir, exist_ok=True)
readme_path = os.path.join(repo_dir, 'README.md')
with open(readme_path, 'w', encoding='utf-8') as file:
file.write(readme_content)
else:
loguru.logger.error(f"Error fetching README for {repo_name}: {response.status_code}")
loguru.logger.error(response.text)
# Entry point
if __name__ == '__main__':
    # Organization name
    org_name = 'datawhalechina'
    # Directory where the READMEs are exported
    export_dir = "./database/readme_db"  # replace with your actual directory path
    # Fetch the repository list
    repos = get_repos(org_name, TOKEN, export_dir)
    # Fetch the README of each repository
    if repos:
        for repo in repos:
            repo_name = repo['name']
            fetch_repo_readme(org_name, repo_name, TOKEN, export_dir)
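Note that the GitHub REST API caps per_page at 100, so a single call like the one above returns at most 100 repositories. If the organization grows past that, the list has to be fetched page by page. A minimal sketch of a paginated version (the function name get_all_repos is mine, not part of the project):

import requests

def get_all_repos(org_name, token):
    """Fetch every repository of an organization, following GitHub's pagination."""
    headers = {'Authorization': f'token {token}'}
    url = f'https://api.github.com/orgs/{org_name}/repos'
    repos, page = [], 1
    while True:
        # per_page is capped at 100 by the API; page numbering starts at 1
        response = requests.get(url, headers=headers,
                                params={'per_page': 100, 'page': page})
        response.raise_for_status()
        batch = response.json()
        if not batch:  # an empty page means everything has been read
            break
        repos.extend(batch)
        page += 1
    return repos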
You will need your own GitHub token here. To obtain one:
1. Open the GitHub website and sign in to your account.
2. In the menu at the top right, choose "Settings".
3. On the settings page, select the "Developer settings" tab.
4. In the menu on the left, select "Personal access tokens".
5. Click "Generate new token" to generate a new token.
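Rather than hardcoding the token in the script, you can keep it in a .env file next to the script and let the load_dotenv() call above pick it up. A minimal sketch, assuming a hypothetical variable name GITHUB_TOKEN:

# Contents of .env (keep this file out of version control):
# GITHUB_TOKEN=ghp_xxxxxxxxxxxxxxxx

import os
from dotenv import load_dotenv

load_dotenv()
TOKEN = os.environ['GITHUB_TOKEN']  # fails fast with a KeyError if the variable is missing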
Summarizing the READMEs with an LLM
These README files contain a fair amount of irrelevant information, so we use an LLM to condense each one into a summary:
(The original project uses openai==0.28; here we use the newer openai package.)
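For reference, the main difference between the two versions is the call style; roughly (a comparison sketch, not the project's code — each half assumes the matching package version is installed):

messages = [{"role": "user", "content": "hello"}]

# With openai==0.28 (module-level call, used by the original project):
import openai
completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
text = completion["choices"][0]["message"]["content"]

# With openai>=1.0 (explicit client object, used here; reads OPENAI_API_KEY from the environment):
from openai import OpenAI
client = OpenAI()
completion = client.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
text = completion.choices[0].message.content

With that in mind, the full summarization script: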
import os
from dotenv import load_dotenv
import openai
from get_data import get_repos  # get_repos from the fetch script above (saved as get_data.py)
from bs4 import BeautifulSoup
import markdown
import re
import time
from openai import OpenAI
# Load environment variables
load_dotenv()
TOKEN = 'your github token'  # GitHub token, used by get_repos below
# Set up the OpenAI API client
openai_api_key = os.environ["OPENAI_API_KEY"]
openai.base_url = 'https://api.chatanywhere.tech/v1'
# Strip URLs from the text to avoid tripping the LLM provider's content moderation
def remove_urls(text):
    # Regex pattern matching URLs
    url_pattern = re.compile(r'https?://[^\s]*')
    # Remove every matched URL
    text = re.sub(url_pattern, '', text)
    # Regex pattern matching boilerplate phrases unrelated to the repository content
    specific_text_pattern = re.compile(r'掃描下方二維碼關注公眾號|提取碼|關注|回覆關鍵詞|侵權|版權|致謝|引用|LICENSE'
                                       r'|組隊打卡|任務打卡|組隊學習的那些事|學習週期|開源內容|打卡|組隊學習|連結')
    # Remove every matched phrase
    text = re.sub(specific_text_pattern, '', text)
    return text
# Extract plain text from a Markdown document
def extract_text_from_md(md_content):
# Convert Markdown to HTML
html = markdown.markdown(md_content)
# Use BeautifulSoup to extract text
soup = BeautifulSoup(html, 'html.parser')
return remove_urls(soup.get_text())
def generate_llm_summary(repo_name, readme_content, model):
prompt = f"1:這個倉庫名是 {repo_name}. 此倉庫的readme全部內容是: {readme_content}\
2:請用約200以內的中文概括這個倉庫readme的內容,返回的概括格式要求:這個倉庫名是...,這倉庫內容主要是..."
    # Build the request; pass the key to the client instead of setting module-level state
    messages = [{"role": "system", "content": "你是一個人工智慧助手"},
                {"role": "user", "content": prompt}]
    llm = OpenAI(api_key=openai_api_key, base_url=openai.base_url)
response = llm.chat.completions.create(
model=model,
messages=messages,
)
return response.choices[0].message.content
def main(org_name, export_dir, summary_dir, model):
repos = get_repos(org_name, TOKEN, export_dir)
# Create a directory to save summaries
os.makedirs(summary_dir, exist_ok=True)
    for idx, repo in enumerate(repos):  # idx avoids shadowing the built-in id
repo_name = repo['name']
readme_path = os.path.join(export_dir, repo_name, 'README.md')
print(repo_name)
if os.path.exists(readme_path):
with open(readme_path, 'r', encoding='utf-8') as file:
readme_content = file.read()
# Extract text from the README
readme_text = extract_text_from_md(readme_content)
# Generate a summary for the README
            # The API may be rate-limited; uncomment to wait between calls
            # time.sleep(60)
            print(f'Starting summary {idx}')
try:
summary = generate_llm_summary(repo_name, readme_text, model)
print(summary)
# Write summary to a Markdown file in the summary directory
summary_file_path = os.path.join(summary_dir, f"{repo_name}_summary.md")
with open(summary_file_path, 'w', encoding='utf-8') as summary_file:
summary_file.write(f"# {repo_name} Summary\n\n")
summary_file.write(summary)
            except openai.OpenAIError as e:
                # Flag blocked READMEs with "風控" (content moderation) in the file name;
                # these placeholder files are filtered out later when building the vector DB
                summary_file_path = os.path.join(summary_dir, f"{repo_name}_summary風控.md")
                with open(summary_file_path, 'w', encoding='utf-8') as summary_file:
                    summary_file.write(f"# {repo_name} Summary風控\n\n")
                    summary_file.write("README內容風控。\n")
                print(f"Error generating summary for {repo_name}: {e}")
# print(readme_text)
        else:
            print(f"README not found: {readme_path}")
            # If the README doesn't exist, write a placeholder file whose name contains
            # "不存在" (missing), so it is filtered out when building the vector DB
            summary_file_path = os.path.join(summary_dir, f"{repo_name}_summary不存在.md")
            with open(summary_file_path, 'w', encoding='utf-8') as summary_file:
                summary_file.write(f"# {repo_name} Summary不存在\n\n")
                summary_file.write("README檔案不存在。\n")
if __name__ == '__main__':
    # Organization name
    org_name = 'datawhalechina'
    # Directory holding the downloaded READMEs
    export_dir = "./database/readme_db"  # replace with your actual README directory
    summary_dir = "./data_base/knowledge_db/readme_summary"  # replace with your actual summary directory
    model = "gpt-3.5-turbo"  # alternatives: deepseek-chat, moonshot-v1-8k
    main(org_name, export_dir, summary_dir, model)
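If the provider throttles requests (note the commented-out time.sleep(60) above), a simple retry with exponential backoff around the call is usually enough. A sketch reusing generate_llm_summary from above (the wrapper itself is mine, not part of the project):

import time
import openai

def summarize_with_retry(repo_name, readme_text, model, max_retries=3):
    """Retry generate_llm_summary with exponential backoff on rate-limit errors."""
    delay = 10  # seconds; doubled after each failed attempt
    for attempt in range(max_retries):
        try:
            return generate_llm_summary(repo_name, readme_text, model)
        except openai.RateLimitError:
            if attempt == max_retries - 1:
                raise
            time.sleep(delay)
            delay *= 2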
Running this yields a summary file for each README, 100 in total.
Building the Vector Database with ZhipuAI
import os
import sys
import re
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import tempfile
from dotenv import load_dotenv, find_dotenv
from embed import ZhipuAIEmbeddings
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.vectorstores import Chroma
# Basic configuration
DEFAULT_DB_PATH = "data_base/knowledge_db/readme_summary"
DEFAULT_PERSIST_PATH = "./vector_db/chroma"  # the QA chain below loads from this same path
def get_files(dir_path):
    # Recursively collect every file path under dir_path
    file_list = []
    for dirpath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            file_list.append(os.path.join(dirpath, filename))
    return file_list
def file_loader(file, loaders):
    # Unwrap temporary file objects (e.g. from a web upload) to a plain path
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    # If given a directory, recurse into its entries
    if not os.path.isfile(file):
        [file_loader(os.path.join(file, f), loaders) for f in os.listdir(file)]
        return
    file_type = file.split('.')[-1]
    if file_type == 'pdf':
        loaders.append(PyMuPDFLoader(file))
    elif file_type == 'md':
        # Skip the placeholder summaries flagged "不存在" (missing) or "風控" (blocked)
        pattern = r"不存在|風控"
        match = re.search(pattern, file)
        if not match:
            loaders.append(UnstructuredMarkdownLoader(file))
    elif file_type == 'txt':
        loaders.append(UnstructuredFileLoader(file))
    return
def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
    if embeddings in ('openai', 'm3e', 'zhipuai'):
        vectordb = create_db(files, persist_directory, embeddings)
    return ""
def get_embedding(embedding: str, embedding_key: str = None, env_file: str = None):
if embedding == "zhipuai":
return ZhipuAIEmbeddings(zhipuai_api_key=os.environ['ZHIPUAI_API_KEY'])
else:
        raise ValueError(f"embedding {embedding} not supported")
def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
    """
    Load the files, split the documents, generate embeddings, and build the vector database.
    Arguments:
        files: path(s) of the files to load.
        persist_directory: directory where the database is persisted.
        embeddings: the model used to produce the embeddings.
    Returns:
        vectordb: the created database.
    """
    if files is None:
        return "can't load empty file"
    if not isinstance(files, list):
        files = [files]
    loaders = []
    [file_loader(file, loaders) for file in files]
    docs = []
    for loader in loaders:
        if loader is not None:
            docs.extend(loader.load())
    # Split the documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=150)
    split_docs = text_splitter.split_documents(docs)
    if isinstance(embeddings, str):
        embeddings = get_embedding(embeddings)
    # Build the database; persist_directory lets Chroma save it to disk
    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory
    )
    vectordb.persist()
    return vectordb
def persist_knowledge_db(vectordb):
    """
    Persist the vector database to disk.
    Arguments:
        vectordb: the vector database to persist.
    """
    vectordb.persist()
def load_knowledge_db(path, embeddings):
    """
    Load a persisted vector database.
    Arguments:
        path: path of the vector database to load.
        embeddings: the embedding model the database was built with.
    Returns:
        vectordb: the loaded database.
    """
    vectordb = Chroma(
        persist_directory=path,
        embedding_function=embeddings
    )
    return vectordb
if __name__ == "__main__":
create_db(embeddings="zhipuai")
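Once the database is built, a quick similarity search is a handy sanity check that the embeddings and persistence work. A minimal sketch, run in the same module as the script above (the query string is just an example):

embedding = ZhipuAIEmbeddings(zhipuai_api_key=os.environ['ZHIPUAI_API_KEY'])
vectordb = load_knowledge_db(DEFAULT_PERSIST_PATH, embedding)
# Retrieve the 3 chunks most similar to the query
docs = vectordb.similarity_search("機器學習", k=3)
for doc in docs:
    print(doc.metadata.get('source'), doc.page_content[:80])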
Building the QA Chain with ZhipuAI
import os
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
import openai
from embed import ZhipuAIEmbeddings
openai.base_url = 'https://api.chatanywhere.tech/v1'
chatgpt = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, base_url=openai.base_url)
persist_directory = 'vector_db/chroma'
embedding = ZhipuAIEmbeddings(zhipuai_api_key=os.environ['ZHIPUAI_API_KEY'])
vectordb = Chroma(
    persist_directory=persist_directory,  # load the database persisted to disk earlier
    embedding_function=embedding
)
class Chat_QA_chain_self:
    """
    QA chain with chat history.
    - model: name of the model to call
    - temperature: sampling temperature, controls generation randomness
    - top_k: number of most similar documents to retrieve
    - chat_history: chat history, a list, empty by default
    - history_len: number of most recent dialogue turns to keep
    - file_path: path of the files used to build the database
    - persist_path: persistence path of the vector database
    - appid: Spark only
    - api_key: required by Spark, Wenxin, OpenAI, and Zhipu
    - Spark_api_secret: Spark API secret
    - Wenxin_secret_key: Wenxin secret key
    - embedding: the embedding model to use
    - embedding_key: API key of the embedding model (Zhipu or OpenAI)
    """
    def __init__(self, model: str, temperature: float = 0.0, top_k: int = 4, chat_history: list = None,
                 file_path: str = None, persist_path: str = None, appid: str = None, api_key: str = None,
                 Spark_api_secret: str = None, Wenxin_secret_key: str = None, embedding="openai",
                 embedding_key: str = None):
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        # Avoid the shared-mutable-default pitfall: create a fresh list per instance
        self.chat_history = chat_history if chat_history is not None else []
# self.history_len = history_len
self.file_path = file_path
self.persist_path = persist_path
self.appid = appid
self.api_key = api_key
self.Spark_api_secret = Spark_api_secret
self.Wenxin_secret_key = Wenxin_secret_key
self.embedding = embedding
self.embedding_key = embedding_key
        # self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding, self.embedding_key)
        self.vectordb = vectordb  # use the module-level vector DB loaded above
    def clear_history(self):
        "Clear the chat history"
        return self.chat_history.clear()
    def change_history_length(self, history_len: int = 1):
        """
        Keep only the most recent dialogue turns.
        Arguments:
        - history_len: number of most recent turns to keep
        - chat_history: the current chat history
        Returns: the most recent history_len turns
        """
n = len(self.chat_history)
return self.chat_history[n - history_len:]
    def answer(self, question: str = None, temperature=None, top_k=4):
        """
        Core method: invoke the QA chain.
        Arguments:
        - question: the user's question
        """
        if not question:
            return "", self.chat_history
        if temperature is None:
            temperature = self.temperature
        # llm = model_to_llm(self.model, temperature, self.appid, self.api_key, self.Spark_api_secret,
        #                    self.Wenxin_secret_key)
        llm = chatgpt
        # self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
        retriever = self.vectordb.as_retriever(search_type="similarity",
                                               search_kwargs={'k': top_k})  # defaults: similarity search, k=4
qa = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=retriever
)
        result = qa({"question": question, "chat_history": self.chat_history})  # result holds question, chat_history, answer
        answer = result['answer']
        self.chat_history.append((question, answer))  # update the history
        return self.chat_history  # the latest answer is the last element of the returned history
if __name__ == '__main__':
question_1 = "給我介紹1個 Datawhale 的機器學習專案"
qa_chain = Chat_QA_chain_self(model="gpt-3.5-turbo")
result = qa_chain.answer(question=question_1)
print("大模型+知識庫後回答 question_1 的結果:")
print(result[0][1])
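Because the chain threads chat_history through every call, follow-up questions can refer back to earlier turns. Continuing the example above (question_2 is illustrative):

# A follow-up in the same session; the chain sees the first turn via chat_history
question_2 = "這個專案適合初學者嗎?"
result = qa_chain.answer(question=question_2)
print(result[-1][1])  # answer of the most recent turn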
The result: