Implementing a Semantic Cache to Improve a RAG System


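In this notebook, we will explore a typical RAG solution built with an open-source model and the Chroma DB vector database. On top of that, we will integrate a semantic cache system that stores user queries and decides whether to build the prompt with information from the vector database or from the cache.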





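Most tutorials that guide you through building a RAG system are designed for a single user and are meant to run in a test environment: in other words, in a notebook that interacts with a local vector database and either makes API calls or uses a locally stored model.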



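In a RAG system, two steps are time-consuming:

  1. Retrieving the information used to build the enriched prompt
  2. Calling the large language model to obtain the response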


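A cache can be placed at either step; in this notebook it sits at the retrieval step. Placing the cache at the model's response step could mean losing influence over the response obtained. The cache could treat "Explain the French Revolution in 10 words" and "Explain the French Revolution in 100 words" as the same query. If it stored model responses, users might conclude that their instructions were not being followed accurately.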
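To make the trade-off concrete, here is a minimal sketch in plain Python. The names `context_cache`, `get_context`, and `build_prompt` are hypothetical, and the cache is keyed by a hand-written topic rather than by real embeddings. It only illustrates that when the retrieved context is cached (and not the model's answer), the two French Revolution requests share one cache entry yet still produce different prompts.

```python
# Toy illustration: cache the retrieved context, not the model's answer.
# A real semantic cache would match questions by embedding distance;
# here a hand-written topic key stands in for that match (assumption).

context_cache = {}
retrieval_calls = 0


def get_context(topic: str) -> str:
    """Return context for a topic, hitting the slow store only on a miss."""
    global retrieval_calls
    if topic not in context_cache:
        retrieval_calls += 1  # stands in for a vector database query
        context_cache[topic] = f"Background documents about {topic}."
    return context_cache[topic]


def build_prompt(context: str, question: str) -> str:
    """Combine cached context with the user's exact wording."""
    return f"Relevant context: {context}\n\nThe user's question: {question}"


q1 = "Explain the French Revolution in 10 words"
q2 = "Explain the French Revolution in 100 words"

p1 = build_prompt(get_context("French Revolution"), q1)
p2 = build_prompt(get_context("French Revolution"), q2)

print(retrieval_calls)  # the second request reused the cached context
print(p1 != p2)         # each prompt still carries its own instruction
```

Because the user's wording is appended after the cache lookup, the length instruction still reaches the model even on a cache hit.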




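First, we need to install the necessary Python packages.

sentence-transformers: this library is needed to convert sentences into fixed-length vectors, also known as embeddings.
chromadb: this is our vector database. ChromaDB is easy to use and open source, and is possibly the most widely used vector database for storing embeddings.
accelerate: needed to run the model on a GPU.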

!pip install -q transformers==4.38.1
!pip install -q accelerate==0.27.2
!pip install -q sentence-transformers==2.5.1
!pip install -q xformers==0.0.24
!pip install -q chromadb==0.4.24
!pip install -q datasets==2.17.1

import numpy as np
import pandas as pd


Because we are working in a free, limited environment with only a few GB of memory, I use the MAX_ROWS variable to limit the number of dataset rows that are used.

# Log in to Hugging Face. It is mandatory in order to use the Gemma model,
# and recommended for accessing public models and datasets.
from getpass import getpass
if 'hf_key' not in locals():
  hf_key = getpass("Your Hugging Face API Key: ")
!huggingface-cli login --token $hf_key
Your Hugging Face API Key: ··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful
from datasets import load_dataset

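data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split='train')
data = data.to_pandas()
# Copy the DataFrame index into an id column (it appears as the last column below)
data["id"] = data.index
data.head(10)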
   qtype            Question                                          Answer                                            id
0  susceptibility   Who is at risk for Lymphocytic Choriomeningiti…   LCMV infections can occur after exposure to fr…   0
1  symptoms         What are the symptoms of Lymphocytic Choriomen…   LCMV is most commonly recognized as causing ne…   1
2  susceptibility   Who is at risk for Lymphocytic Choriomeningiti…   Individuals of all ages who come into contact …   2
3  exams and tests  How to diagnose Lymphocytic Choriomeningitis (…   During the first phase of the disease, the mos…   3
4  treatment        What are the treatments for Lymphocytic Chorio…   Aseptic meningitis, encephalitis, or meningoen…   4
5  prevention       How to prevent Lymphocytic Choriomeningitis (L…   LCMV infection can be prevented by avoiding co…   5
6  information      What is (are) Parasites – Cysticercosis ?         Cysticercosis is an infection caused by the la…   6
7  susceptibility   Who is at risk for Parasites – Cysticercosis? ?   Cysticercosis is an infection caused by the la…   7
8  exams and tests  How to diagnose Parasites – Cysticercosis ?       If you think that you may have cysticercosis, …   8
9  treatment        What are the treatments for Parasites – Cystic…   Some people with cysticercosis do not need to …   9
MAX_ROWS = 15000

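ChromaDB requires every record to have a unique identifier. A single statement that copies the DataFrame index into a new column named id takes care of this.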

# Because this is just a sample, we select a small portion of the dataset.
subset_data = data.head(MAX_ROWS)


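I will use ChromaDB, arguably the most popular open-source vector database.

First we import ChromaDB, and then the Settings class from the chromadb.config module. This class lets us change the ChromaDB settings and customize its behavior.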

import chromadb
from chromadb.config import Settings


chroma_client = chromadb.PersistentClient(path="/path/to/persist/directory")

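5. Populating and querying the ChromaDB database

Data in ChromaDB is stored in collections. If the collection already exists, we need to delete it.

In the next lines, we create the collection by calling the create_collection function on the chroma_client created above.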

collection_name = "news_collection"
# Delete the collection if it already exists so we start from a clean state
if collection_name in [c.name for c in chroma_client.list_collections()]:
    chroma_client.delete_collection(name=collection_name)

collection = chroma_client.create_collection(name=collection_name)

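It is time to add the data to the collection. With the add function we must provide at least the documents, the metadata, and the IDs.

For the IDs, each row needs a unique identifier. It must be unique! I am building the IDs from the MAX_ROWS range.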

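DOCUMENT, TOPIC = "Answer", "qtype"  # dataset column names (assumed from the preview above)
collection.add(
    documents=subset_data[DOCUMENT].tolist(),
    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
    ids=[f"id{x}" for x in range(MAX_ROWS)],
)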
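Since the lists of documents, metadata, and IDs must line up one-to-one and the IDs must not repeat, a quick sanity check before adding data can save a confusing error later. The sketch below uses small stand-in lists instead of the real dataset, and `make_ids` is a hypothetical helper, not part of ChromaDB.

```python
def make_ids(n_rows: int, prefix: str = "id") -> list[str]:
    """Build one identifier per row: id0, id1, ..."""
    return [f"{prefix}{i}" for i in range(n_rows)]


documents = ["doc a", "doc b", "doc c"]        # stand-in for the answer column
metadatas = [{"qtype": "symptoms"},
             {"qtype": "treatment"},
             {"qtype": "prevention"}]          # stand-in for the topic column
ids = make_ids(len(documents))

# The three lists passed to add() must line up one-to-one.
assert len(documents) == len(metadatas) == len(ids)
# Every identifier must be distinct.
assert len(set(ids)) == len(ids)
print(ids)
```

Deriving the ID count from the number of documents, rather than from a separate constant, keeps the lists aligned even if the dataset turns out to have fewer rows than expected.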



Let's define a function to query the ChromaDB database.

def query_database(query_text, n_results=10):
    results = collection.query(query_texts=query_text, n_results=n_results )
    return results
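The dictionary returned by a query holds one inner list per query text, which is why later code reads `answer['documents'][0][0]`. The hand-built dictionary below only mirrors that nested layout; its values are invented for illustration, and `first_document` is a hypothetical helper.

```python
# Hand-built stand-in for a query result with one query text and two hits.
# Real results come from collection.query(); these values are invented.
fake_results = {
    "ids": [["id7", "id6"]],
    "documents": [["First matching answer.", "Second matching answer."]],
    "metadatas": [[{"qtype": "susceptibility"}, {"qtype": "information"}]],
    "distances": [[0.41, 0.57]],
}


def first_document(results: dict) -> str:
    """Return the best hit for the first (and here only) query text."""
    # Outer index selects the query text, inner index selects the hit.
    return results["documents"][0][0]


print(first_document(fake_results))
```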

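6. Creating the semantic cache system

To implement the cache system we will use Faiss, a library that keeps embeddings in memory. It is quite similar to what Chroma does, but without persistence.

For this purpose we will create a class named semantic_cache that works with its own encoder and provides the functions needed to run queries.

In this class, we first query Faiss (the cache). If the closest result is within the specified distance threshold, the answer is returned from the cache; otherwise it is fetched from the Chroma database.

The cache is stored in a .json file.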
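The control flow described above can be sketched without Faiss or Chroma. In this minimal, assumption-laden version, `TinyCache`, `ask`, and `slow_lookup` are hypothetical names, the vectors are hand-written two-dimensional stand-ins for embeddings, and the squared Euclidean distance is compared against the threshold.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


class TinyCache:
    """Hypothetical in-memory cache: nearest stored vector within a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.vectors = []
        self.answers = []

    def lookup(self, vector):
        """Return the cached answer for the closest vector, or None on a miss."""
        if not self.vectors:
            return None
        best = min(range(len(self.vectors)),
                   key=lambda i: squared_l2(vector, self.vectors[i]))
        if squared_l2(vector, self.vectors[best]) <= self.threshold:
            return self.answers[best]
        return None

    def add(self, vector, answer):
        self.vectors.append(vector)
        self.answers.append(answer)


def ask(cache, vector, slow_lookup):
    """Try the cache first; fall back to the slow source and remember the result."""
    cached = cache.lookup(vector)
    if cached is not None:
        return cached, "cache"
    answer = slow_lookup(vector)
    cache.add(vector, answer)
    return answer, "database"


cache = TinyCache(threshold=0.35)
slow_lookup = lambda v: f"answer for {v}"   # stands in for the vector database

print(ask(cache, [0.0, 1.0], slow_lookup)[1])   # first question: miss
print(ask(cache, [0.1, 1.0], slow_lookup)[1])   # close paraphrase: hit
print(ask(cache, [1.0, 0.0], slow_lookup)[1])   # unrelated question: miss
```

The key design point is that a miss both answers the question and enriches the cache, so the next similar question is served without touching the slow source.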

!pip install -q faiss-cpu==1.8.0
import faiss
from sentence_transformers import SentenceTransformer
import time
import json


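The init_cache function below uses a FlatL2 index, which may not be the fastest option but is well suited to small datasets. Depending on the characteristics of the data to be cached and the expected dataset size, other indexes such as HNSW or IVF can be used.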

def init_cache():
    # 768 is the embedding dimension produced by all-mpnet-base-v2
    index = faiss.IndexFlatL2(768)
    if index.is_trained:
        print('Index trained')

    # Initialize Sentence Transformer model
    encoder = SentenceTransformer('all-mpnet-base-v2')

    return index, encoder
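A flat index compares the query against every stored vector. The pure-Python sketch below mimics that behaviour, assuming (as Faiss does for its flat L2 index) that the reported distance is the squared Euclidean distance and that an empty index yields the position -1. `flat_l2_search` is a hypothetical function and is not part of Faiss; returning infinity for the empty case is a simplification of this sketch.

```python
def flat_l2_search(stored, query):
    """Exhaustive nearest-neighbour search.

    Returns (squared_distance, position) of the closest stored vector,
    or (inf, -1) when nothing is stored yet.
    """
    best_distance, best_position = float("inf"), -1
    for position, vector in enumerate(stored):
        distance = sum((q - v) ** 2 for q, v in zip(query, vector))
        if distance < best_distance:
            best_distance, best_position = distance, position
    return best_distance, best_position


stored = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(flat_l2_search(stored, [0.0, 0.5, 0.0]))  # closest is the second vector
print(flat_l2_search([], [0.0, 0.5, 0.0]))      # empty index: position -1
```

The cost of this search grows linearly with the number of stored vectors, which is why approximate indexes such as HNSW or IVF become attractive for larger caches.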


def retrieve_cache(json_file):
    # Load the cache from disk, or start with an empty one if the file does not exist
    try:
        with open(json_file, 'r') as file:
            cache = json.load(file)
    except FileNotFoundError:
        cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}

    return cache

The store_cache function saves the file containing the cache data to disk.

def store_cache(json_file, cache):
    with open(json_file, 'w') as file:
        json.dump(cache, file)
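The load and save steps can be exercised together as a round trip. The sketch below restates both helpers in compact form under the hypothetical names `load_cache` and `save_cache` so that it runs on its own, and it writes to a temporary directory rather than to the notebook's cache file.

```python
import json
import os
import tempfile

EMPTY_CACHE_KEYS = ("questions", "embeddings", "answers", "response_text")


def load_cache(path):
    """Read the cache file, or start empty when it does not exist yet."""
    try:
        with open(path, "r") as file:
            return json.load(file)
    except FileNotFoundError:
        return {key: [] for key in EMPTY_CACHE_KEYS}


def save_cache(path, cache):
    """Write the whole cache back to disk as JSON."""
    with open(path, "w") as file:
        json.dump(cache, file)


with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "cache_file.json")

    cache = load_cache(path)             # file is missing: empty structure
    cache["questions"].append("What is a vaccine?")
    cache["embeddings"].append([0.1, 0.2, 0.3])   # plain lists serialize to JSON
    cache["answers"].append(None)
    cache["response_text"].append("A short answer.")
    save_cache(path, cache)

    reloaded = load_cache(path)

print(reloaded["questions"])
```

Note that JSON cannot store NumPy arrays directly, so embeddings have to be converted to plain lists before they are written.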

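These functions are used in the semantic_cache class, which contains the search function and its initialization function.

Although the ask function has a lot of code, its purpose is quite simple. It looks in the cache for the question closest to the one the user has just asked.

It then checks whether that question is within the specified threshold. If so, it returns the response directly from the cache; otherwise, it calls the query_database function to retrieve the data from ChromaDB.

I used Euclidean distance rather than cosine similarity, which is widely used for comparing vectors. The choice rests on the fact that Euclidean distance is the default metric in Faiss. Cosine distance can also be computed, but doing so adds complexity and would probably not change the final result significantly.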
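The two metrics are closely related. If the embeddings are unit-length and the index reports squared L2 distances (as the Faiss flat L2 index does), then the squared distance equals 2 - 2·cos, so a distance threshold corresponds to a cosine similarity floor. The sketch below checks this identity with hand-written vectors; whether a given encoder really produces unit-length vectors is an assumption to verify.

```python
import math


def normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine(a, b):
    # For unit-length inputs the dot product is the cosine similarity.
    return sum(x * y for x, y in zip(a, b))


a = normalize([3.0, 4.0, 0.0])
b = normalize([4.0, 3.0, 1.0])

# Identity for unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b)
print(math.isclose(squared_l2(a, b), 2 - 2 * cosine(a, b)))

# Under the same assumptions, a squared-distance threshold maps to a cosine floor.
threshold = 0.35
print(1 - threshold / 2)
```

Under these assumptions, the threshold of 0.35 is roughly equivalent to requiring a cosine similarity of about 0.825 or higher.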

class semantic_cache:
  def __init__(self, json_file="cache_file.json", threshold=0.35):
      # Initialize Faiss index with Euclidean distance
      self.index, self.encoder = init_cache()

      # Set Euclidean distance threshold
      # a distance of 0 means identical sentences
      # We only return from cache sentences under this threshold
      self.euclidean_threshold = threshold

      self.json_file = json_file
      self.cache = retrieve_cache(self.json_file)
      # Re-add embeddings loaded from disk so index rows match the cache lists
      if self.cache['embeddings']:
          self.index.add(np.array(self.cache['embeddings'], dtype='float32'))

  def ask(self, question: str) -> str:
      # Method to retrieve an answer from the cache or generate a new one
      start_time = time.time()
      try:
          # First we obtain the embeddings corresponding to the user question
          embedding = self.encoder.encode([question])

          # Search for the nearest neighbor in the index
          # (a flat index is searched exhaustively, so nprobe does not apply)
          D, I = self.index.search(embedding, 1)

          if D[0][0] >= 0:
              if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
                  row_id = int(I[0][0])

                  print('Answer recovered from Cache. ')
                  print(f'{D[0][0]:.3f} smaller than {self.euclidean_threshold}')
                  print(f'Found cache in row: {row_id} with score {D[0][0]:.3f}')
                  print(f'response_text: ' + self.cache['response_text'][row_id])

                  end_time = time.time()
                  elapsed_time = end_time - start_time
                  print(f"Time taken: {elapsed_time:.3f} seconds")
                  return self.cache['response_text'][row_id]

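          # Handle the case when there are not enough results
          # or the Euclidean distance threshold is not met: ask ChromaDB.
          answer = query_database([question], 1)
          response_text = answer['documents'][0][0]

          # Record the new entry in the in-memory cache
          self.cache['questions'].append(question)
          self.cache['embeddings'].append(embedding[0].tolist())
          self.cache['answers'].append(answer)
          self.cache['response_text'].append(response_text)

          print('Answer recovered from ChromaDB. ')
          print(f'response_text: {response_text}')

          # Add the embedding to the index, then persist the cache to disk
          self.index.add(embedding)
          store_cache(self.json_file, self.cache)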
          end_time = time.time()
          elapsed_time = end_time - start_time
          print(f"Time taken: {elapsed_time:.3f} seconds")

          return response_text
      except Exception as e:
          raise RuntimeError(f"Error during 'ask' method: {e}")

6.1 Testing the semantic_cache class

# Initialize the cache.
cache = semantic_cache('4cache_file.json')
Index trained
results = cache.ask("How work a vaccine?")
Answer recovered from ChromaDB. 
response_text: Summary : Shots may hurt a little, but the diseases they can prevent are a lot worse. Some are even life-threatening. Immunization shots, or vaccinations, are essential. They protect against things like measles, mumps, rubella, hepatitis B, polio, tetanus, diphtheria, and pertussis (whooping cough). Immunizations are important for adults as well as children.    Your immune system helps your body fight germs by producing substances to combat them. Once it does, the immune system "remembers" the germ and can fight it again. Vaccines contain germs that have been killed or weakened. When given to a healthy person, the vaccine triggers the immune system to respond and thus build immunity.     Before vaccines, people became immune only by actually getting a disease and surviving it. Immunizations are an easier and less risky way to become immune.     NIH: National Institute of Allergy and Infectious Diseases
Time taken: 0.655 seconds

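As expected, the response was obtained from ChromaDB. The class then stores it in the cache.

Now, if we send a second, completely different question, the response should also be retrieved from ChromaDB. The previously stored question is so different that its Euclidean distance exceeds the specified threshold.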

results = cache.ask("Explain briefly what is a Periodic Paralyses")
Answer recovered from ChromaDB. 
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.
Time taken: 0.083 seconds



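The next question is almost identical to the previous one. In this case, the response should come directly from the cache, without accessing the ChromaDB database.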

results = cache.ask("Briefly explain me what is a periodic paralyses")
Answer recovered from Cache. 
0.018 smaller than 0.35
Found cache in row: 1 with score 0.018
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.
Time taken: 0.015 seconds



question_def = "Write in 20 words what is a periodic paralyses"
results = cache.ask(question_def)
Answer recovered from Cache. 
0.220 smaller than 0.35
Found cache in row: 1 with score 0.220
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.
Time taken: 0.017 seconds
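This last hit shows the trade-off controlled by the threshold. The sketch below reuses the two distances printed above (0.018 and 0.220) and checks them against the default threshold of 0.35 and against a stricter value of 0.10, which is purely illustrative. `cache_hits` is a hypothetical helper.

```python
# Distances reported by the cache for the two similar questions above.
observed = {
    "Briefly explain me what is a periodic paralyses": 0.018,
    "Write in 20 words what is a periodic paralyses": 0.220,
}


def cache_hits(distances, threshold):
    """Return the questions whose distance falls within the threshold."""
    return [q for q, d in distances.items() if d <= threshold]


print(len(cache_hits(observed, 0.35)))  # default threshold used in this notebook
print(cache_hits(observed, 0.10))       # illustrative stricter threshold
```

With the stricter value, the "20 words" question would go back to ChromaDB. Either way, because the cache holds retrieved context rather than model responses, the length instruction still reaches the model through the prompt.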


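7. Loading the model and creating the prompt

It is time to use the transformers library, Hugging Face's best-known library for working with language models.

AutoModelForCausalLM: provides an interface to pretrained language models designed for language generation tasks using causal language modeling, such as GPT models or Gemma-2b-it, the model used in this notebook.

The selected model is Gemma-2b-it.

Feel free to test different models; look for NLP models trained for text generation.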

!pip install torch
import torch
from torch import cuda
#In a MAC Silicon the device must be 'mps'
# device = torch.device('mps') #to use with MAC Silicon
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

from transformers import AutoTokenizer, AutoModelForCausalLM

#model_id = "databricks/dolly-v2-3b"
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
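# The source shows only the first argument of this call; moving the model to `device` is an assumption.
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

8. Creating the extended prompt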




prompt_template = f"Relevant context: {results}\n\n The user's question: {question_def}"
"Relevant context: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.\n                \nThe two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.\n\n The user's question: Write in 20 words what is a periodic paralyses"
input_ids = tokenizer(prompt_template, return_tensors="pt").to(device)


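# The remaining arguments of this call are cut off in the source; max_new_tokens=100 is an assumed value.
outputs = model.generate(**input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))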
<bos>Relevant context: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.

 The user's question: Write in 20 words what is a periodic paralyses?

Answer: A group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells.<eos>

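9. Conclusion

Going directly to the cache instead of ChromaDB reduced access time by at least 50% in these runs. In larger projects the gap is likely to widen, with improvements in the 90-95% range.

We have very little data in Chroma and only a single instance of the cache class. Normally, the data behind a cache system is much larger and may involve more than a query to a vector database, coming from several different sources.


In summary, we built a very simple RAG (Retrieval-Augmented Generation) system and enhanced it by adding a semantic cache layer between the user's question and the retrieval of the information needed to build the enriched prompt.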
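The improvement can be estimated from the timings printed earlier in this notebook: 0.083 seconds for the second ChromaDB lookup and 0.015 seconds for the first cache hit. The helper `time_saved_percent` is hypothetical, and single measurements like these are noisy, so the result is indicative only.

```python
def time_saved_percent(baseline_seconds, cached_seconds):
    """Relative reduction in elapsed time, as a percentage of the baseline."""
    return (baseline_seconds - cached_seconds) / baseline_seconds * 100


chroma_seconds = 0.083   # second ChromaDB lookup printed earlier
cache_seconds = 0.015    # first cache hit printed earlier

print(f"{time_saved_percent(chroma_seconds, cache_seconds):.1f}%")
```

Averaging over many queries, and excluding the first ChromaDB call (which includes warm-up cost), would give a more reliable figure.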
