實施語義快取以改進 RAG 系統
1.快取介紹
在本筆記本中,我們將探索一個典型的 RAG 解決方案,其中我們將使用開源模型和向量資料庫 Chroma DB。但是,我們將整合一個語義快取系統,該系統將儲存各種使用者查詢,並決定是否生成包含來自向量資料庫或快取的資訊的提示。
語義快取系統旨在識別相似或相同的使用者請求。當找到匹配的請求時,系統會從快取中檢索相應的資訊,從而減少從原始源獲取它的需要。
由於比較考慮了請求的語義含義,因此它們不必完全相同,系統就可以將它們識別為同一個問題。它們可以以不同的方式表達或包含不準確之處,無論是印刷錯誤還是句子結構,我們都可以確定使用者實際上正在請求相同的資訊。
例如,像“法國的首都是什麼?”、告訴我法國首都的名字?和“法國的首都是什麼?”這樣的查詢都傳達了相同的意圖,應該被識別為同一個問題。
雖然模型的響應可能因第二個示例中對簡潔答案的請求而有所不同,但從向量資料庫檢索到的資訊應該是相同的。這就是為什麼我將快取系統放在使用者和向量資料庫之間,而不是使用者和大型語言模型之間。
大多數指導您建立 RAG 系統的教程都是為單使用者使用而設計的,旨在在測試環境中執行。換句話說,在筆記本中,與本地向量資料庫互動並進行 API 呼叫或使用本地儲存的模型。
當嘗試將其中一個模型轉換為生產時,這種架構很快就會變得不足,因為它們可能會遇到數十到數千個重複請求。
提高效能的一種方法是透過一個或多個語義快取。此快取保留先前請求的結果,在解析新請求之前,它會檢查之前是否收到過類似的請求。如果是,它不會重新執行該過程,而是從快取中檢索資訊。
在 RAG 系統中,有兩點很耗時:
- 檢索用於構建豐富提示的資訊
- 呼叫大型語言模型以獲取響應
在這兩個點中,都可以實現語義快取系統,我們甚至可以有兩個快取,每個點一個。
將其放置在模型的響應點可能會導致對所獲得響應的影響喪失。我們的快取系統可以將“用 10 個字解釋法國大革命”和“用 100 個字解釋法國大革命”視為相同的查詢。如果我們的快取系統儲存模型響應,使用者可能會認為他們的指令沒有被準確遵循。
但這兩個請求都需要相同的資訊來豐富提示。這就是我選擇將語義快取系統放置在使用者請求和從向量資料庫檢索資訊之間的主要原因。
但是,這是一個設計決策。根據響應和系統請求的型別,它可以放在一個點或另一個點。很明顯,快取模型響應可以節省最多的時間,但正如我已經解釋過的,這是以失去使用者對響應的影響為代價的。
2.匯入並載入庫。
首先,我們需要安裝必要的 Python 包。
sentence transformers.:這個庫對於將句子轉換為固定長度的向量(也稱為嵌入)是必需的。
xformers:它是一個提供庫和實用程式的包,以方便使用轉換器模型。我們需要安裝它以避免在使用模型和嵌入時出現錯誤。
chromadb:這是我們的向量資料庫。ChromaDB 易於使用且開源,可能是用於儲存嵌入的最常用的向量資料庫。
accelerate:需要在 GPU 中執行模型。
!pip install -q transformers==4.38.1
!pip install -q accelerate==0.27.2
!pip install -q sentence-transformers==2.5.1
!pip install -q xformers==0.0.24
!pip install -q chromadb==0.4.24
!pip install -q datasets==2.17.1
import numpy as np
import pandas as pd
3.載入資料集
由於我們在一個自由且有限的空間內工作,並且只能使用幾 GB 的記憶體,因此我使用變數 MAX_ROWS 限制了資料集中要使用的行數。
#Login to Hugging Face. It is mandatory to use the Gemma Model,
#and recommended to acces public models and Datasets.
from getpass import getpass
if 'hf_key' not in locals():
hf_key = getpass("Your Hugging Face API Key: ")
!huggingface-cli login --token $hf_key
Your Hugging Face API Key: ··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful
from datasets import load_dataset
data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split='train')
data = data.to_pandas()
data["id"]=data.index
data.head(10)
qtype | Question | Answer | id | |
---|---|---|---|---|
0 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti… | LCMV infections can occur after exposure to fr… | 0 |
1 | symptoms | What are the symptoms of Lymphocytic Choriomen… | LCMV is most commonly recognized as causing ne… | 1 |
2 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti… | Individuals of all ages who come into contact … | 2 |
3 | exams and tests | How to diagnose Lymphocytic Choriomeningitis (… | During the first phase of the disease, the mos… | 3 |
4 | treatment | What are the treatments for Lymphocytic Chorio… | Aseptic meningitis, encephalitis, or meningoen… | 4 |
5 | prevention | How to prevent Lymphocytic Choriomeningitis (L… | LCMV infection can be prevented by avoiding co… | 5 |
6 | information | What is (are) Parasites – Cysticercosis ? | Cysticercosis is an infection caused by the la… | 6 |
7 | susceptibility | Who is at risk for Parasites – Cysticercosis? ? | Cysticercosis is an infection caused by the la… | 7 |
8 | exams and tests | How to diagnose Parasites – Cysticercosis ? | If you think that you may have cysticercosis, … | 8 |
9 | treatment | What are the treatments for Parasites – Cystic… | Some people with cysticercosis do not need to … | 9 |
MAX_ROWS = 15000
DOCUMENT="Answer"
TOPIC="qtype"
ChromaDB 要求資料具有唯一識別符號。我們可以使用此語句來實現,它將建立一個名為 Id 的新列。
#Because it is just a sample we select a small portion of News.
subset_data = data.head(MAX_ROWS)
4.匯入和配置向量資料庫
我將使用最流行的開源向量資料庫 ChromaDB。
首先,我們需要匯入 ChromaDB,然後從 chromadb.config 模組匯入 Settings 類。該類允許我們更改 ChromaDB 系統的設定並自定義其行為。
import chromadb
from chromadb.config import Settings
現在我們只需要指明向量資料庫的儲存路徑。
chroma_client = chromadb.PersistentClient(path="/path/to/persist/directory")
5.填充和查詢 ChromaDB 資料庫
ChromaDB 中的資料儲存在集合中。如果集合存在,我們需要刪除它。
在接下來的幾行中,我們將透過呼叫上面建立的 chroma_client 中的 create_collection 函式來建立集合。
collection_name = "news_collection"
if len(chroma_client.list_collections()) > 0 and collection_name in [chroma_client.list_collections()[0].name]:
chroma_client.delete_collection(name=collection_name)
collection = chroma_client.create_collection(name=collection_name)
是時候將資料新增到集合中了。使用 add 函式,我們至少需要通知文件、後設資料和 ID。
在文件中,我們儲存大文字,它是每個資料集中的不同列。
在後設資料中,我們可以通知主題列表。
在 ID 中,我們需要為每行通知一個唯一的識別符號。它必須是唯一的!我正在使用 MAX_ROWS 範圍建立 ID。
collection.add(
documents=subset_data[DOCUMENT].tolist(),
metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
ids=[f"id{x}" for x in range(MAX_ROWS)],
)
一旦我們獲得了資料庫中的資訊,我們就可以查詢它,並請求符合我們需求的資料。搜尋是在文件內容內進行的,它不會查詢確切的單詞或短語。結果將基於搜尋詞和文件內容之間的相似性。
後設資料不用於搜尋,但可用於在初始搜尋後過濾或最佳化結果。
讓我們定義一個函式來查詢 ChromaDB 資料庫。
def query_database(query_text, n_results=10):
results = collection.query(query_texts=query_text, n_results=n_results )
return results
6. 建立語義快取系統
為了實現快取系統,我們將使用 Faiss,這是一個允許將嵌入儲存在記憶體中的庫。它與 Chroma 所做的非常相似,但沒有永續性。
為此,我們將建立一個名為 semantic_cache 的類,它將與其自己的編碼器一起工作,併為使用者提供執行查詢所需的功能。
在這個類中,我們首先查詢 Faiss(快取),如果返回的結果高於指定的閾值,它將從快取中返回結果。否則,它將從 Chroma 資料庫中獲取結果。
快取儲存在 .json 檔案中。
!pip install -q faiss-cpu==1.8.0
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 27.0/27.0 MB 62.3 MB/s eta 0:00:00
import faiss
from sentence_transformers import SentenceTransformer
import time
import json
此函式初始化語義快取。
它採用 FlatLS 索引,這可能不是最快的,但非常適合小型資料集。根據要快取的資料的特徵和預期的資料集大小,可以使用其他索引,例如 HNSW 或 IVF。
def init_cache():
index = faiss.IndexFlatL2(768)
if index.is_trained:
print('Index trained')
# Initialize Sentence Transformer model
encoder = SentenceTransformer('all-mpnet-base-v2')
return index, encoder
在retrieve_cache函式中,如果需要在會話間重用快取,則會從磁碟檢索.json檔案。
def retrieve_cache(json_file):
try:
with open(json_file, 'r') as file:
cache = json.load(file)
except FileNotFoundError:
cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}
return cache
store_cache 函式將包含快取資料的檔案儲存到磁碟。
def store_cache(json_file, cache):
with open(json_file, 'w') as file:
json.dump(cache, file)
這些函式將在 SemanticCache 類中使用,該類包括搜尋函式及其初始化函式。
儘管 ask 函式有大量程式碼,但其目的卻非常簡單。它在快取中查詢與使用者剛剛提出的問題最接近的問題。
然後,檢查它是否在指定的閾值內。如果是肯定的,它直接從快取中返回響應;否則,它呼叫 query_database 函式從 ChromaDB 中檢索資料。
我使用了歐幾里得距離而不是餘弦,後者在向量比較中被廣泛使用。這種選擇是基於歐幾里得距離是 Faiss 使用的預設度量這一事實。雖然也可以計算餘弦距離,但這樣做會增加複雜性,可能不會對最終結果產生重大影響。
class semantic_cache:
def __init__(self, json_file="cache_file.json", thresold=0.35):
# Initialize Faiss index with Euclidean distance
self.index, self.encoder = init_cache()
# Set Euclidean distance threshold
# a distance of 0 means identicals sentences
# We only return from cache sentences under this thresold
self.euclidean_threshold = thresold
self.json_file = json_file
self.cache = retrieve_cache(self.json_file)
def ask(self, question: str) -> str:
# Method to retrieve an answer from the cache or generate a new one
start_time = time.time()
try:
#First we obtain the embeddings corresponding to the user question
embedding = self.encoder.encode([question])
# Search for the nearest neighbor in the index
self.index.nprobe = 8
D, I = self.index.search(embedding, 1)
if D[0] >= 0:
if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
row_id = int(I[0][0])
print('Answer recovered from Cache. ')
print(f'{D[0][0]:.3f} smaller than {self.euclidean_threshold}')
print(f'Found cache in row: {row_id} with score {D[0][0]:.3f}')
print(f'response_text: ' + self.cache['response_text'][row_id])
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Time taken: {elapsed_time:.3f} seconds")
return self.cache['response_text'][row_id]
# Handle the case when there are not enough results
# or Euclidean distance is not met, asking to chromaDB.
answer = query_database([question], 1)
response_text = answer['documents'][0][0]
self.cache['questions'].append(question)
self.cache['embeddings'].append(embedding[0].tolist())
self.cache['answers'].append(answer)
self.cache['response_text'].append(response_text