Implementing a Semantic Cache to Improve a RAG System

Posted by bonelee on 2024-11-28

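1. Introduction to the cache

In this notebook we explore a typical RAG solution built with an open-source model and the Chroma DB vector database. On top of it, we integrate a semantic cache system that stores user queries and decides whether the prompt should be built with information from the vector database or from the cache.

A semantic cache system aims to identify similar or identical user requests. When a matching request is found, the system retrieves the corresponding information from the cache, which reduces the need to fetch it from the original source.

Because the comparison takes the semantic meaning of the requests into account, they do not have to be identical for the system to recognize them as the same question. They can be phrased differently or contain inaccuracies, whether typos or sentence structure, and we can still determine that the user is actually asking for the same information.

For example, queries such as "What is the capital of France?" and "Tell me the name of the capital of France?" convey the same intent and should be recognized as the same question.

Although the model's response may differ because the second example asks for a concise answer, the information retrieved from the vector database should be the same. That is why I place the cache system between the user and the vector database, rather than between the user and the large language model.

Most tutorials that guide you through building a RAG system are designed for a single user and are meant to run in a test environment: in a notebook, interacting with a local vector database, and making API calls or using a locally stored model.

When you try to move one of these systems to production, this architecture quickly becomes insufficient, because it may face anywhere from tens to thousands of repeated requests.

One way to improve performance is to use one or more semantic caches. The cache keeps the results of previous requests and, before a new request is resolved, checks whether a similar request has been received before. If so, it retrieves the information from the cache instead of running the whole process again.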
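The check described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: `toy_embed` is a hypothetical stand-in for a real sentence encoder (it uses character trigrams, so it captures surface similarity rather than meaning), and `ToyCache` with its threshold of 0.5 is invented for this sketch, not part of any library.

```python
import math
from collections import Counter


def toy_embed(text):
    """Hypothetical stand-in for a sentence encoder:
    L2-normalized character-trigram counts stored in a dict."""
    padded = f"  {text.lower()} "
    counts = Counter(padded[i:i + 3] for i in range(len(padded) - 2))
    norm = math.sqrt(sum(v * v for v in counts.values()))
    return {k: v / norm for k, v in counts.items()}


def squared_distance(a, b):
    """Squared Euclidean distance between two sparse vectors."""
    keys = set(a) | set(b)
    return sum((a.get(k, 0.0) - b.get(k, 0.0)) ** 2 for k in keys)


class ToyCache:
    """Keeps (embedding, result) pairs and answers only close-enough queries."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold  # assumed value, chosen for this toy encoder
        self.entries = []

    def store(self, question, result):
        self.entries.append((toy_embed(question), result))

    def lookup(self, question):
        query = toy_embed(question)
        best_result, best_distance = None, float("inf")
        for embedding, result in self.entries:
            distance = squared_distance(query, embedding)
            if distance < best_distance:
                best_result, best_distance = result, distance
        # A miss (None) means the caller must run the full retrieval
        return best_result if best_distance <= self.threshold else None


cache = ToyCache()
cache.store("What is the capital of France?", "Paris context")
hit = cache.lookup("what is the capital of france")   # near-duplicate wording
miss = cache.lookup("How do vaccines work?")          # unrelated question
```

A real system replaces `toy_embed` with a sentence-embedding model, which is what lets differently worded questions land close to each other.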

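In a RAG system, two steps are time-consuming:

  1. Retrieving the information used to build the enriched prompt
  2. Calling the large language model to obtain the response

A semantic cache can be implemented at either of these points, and we could even have two caches, one for each.

Placing the cache at the model-response point can cause the user to lose influence over the response obtained. Our cache system might treat "Explain the French Revolution in 10 words" and "Explain the French Revolution in 100 words" as the same query. If the cache stores model responses, users may feel that their instructions are not being followed accurately.

Both requests, however, need the same information to enrich the prompt. This is the main reason I chose to place the semantic cache between the user's request and the retrieval of information from the vector database.

This is a design decision, though. Depending on the type of responses and requests the system handles, the cache can be placed at one point or the other. Caching model responses clearly saves the most time but, as explained above, it comes at the cost of the user's influence over the response.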

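2. Import and load the libraries

First, we need to install the necessary Python packages.

sentence-transformers: this library is needed to convert sentences into fixed-length vectors, also known as embeddings.
xformers: a package that provides libraries and utilities that make it easier to work with transformer models. We install it to avoid errors when working with the model and the embeddings.
chromadb: this is our vector database. ChromaDB is easy to use and open source, and is probably the most widely used vector database for storing embeddings.
accelerate: needed to run the model on a GPU.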

!pip install -q transformers==4.38.1
!pip install -q accelerate==0.27.2
!pip install -q sentence-transformers==2.5.1
!pip install -q xformers==0.0.24
!pip install -q chromadb==0.4.24
!pip install -q datasets==2.17.1

import numpy as np
import pandas as pd

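3. Load the dataset

Because we are working in a free, resource-limited environment with only a few GB of memory available, I use the MAX_ROWS variable to limit the number of dataset rows that are used.

# Log in to Hugging Face. It is mandatory for using the Gemma model,
# and recommended for accessing public models and datasets.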
from getpass import getpass
if 'hf_key' not in locals():
  hf_key = getpass("Your Hugging Face API Key: ")
!huggingface-cli login --token $hf_key
Your Hugging Face API Key: ··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful
from datasets import load_dataset

data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split='train')
data = data.to_pandas()
data["id"]=data.index
data.head(10)
  | qtype | Question | Answer | id
0 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti… | LCMV infections can occur after exposure to fr… | 0
1 | symptoms | What are the symptoms of Lymphocytic Choriomen… | LCMV is most commonly recognized as causing ne… | 1
2 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti… | Individuals of all ages who come into contact … | 2
3 | exams and tests | How to diagnose Lymphocytic Choriomeningitis (… | During the first phase of the disease, the mos… | 3
4 | treatment | What are the treatments for Lymphocytic Chorio… | Aseptic meningitis, encephalitis, or meningoen… | 4
5 | prevention | How to prevent Lymphocytic Choriomeningitis (L… | LCMV infection can be prevented by avoiding co… | 5
6 | information | What is (are) Parasites – Cysticercosis ? | Cysticercosis is an infection caused by the la… | 6
7 | susceptibility | Who is at risk for Parasites – Cysticercosis? ? | Cysticercosis is an infection caused by the la… | 7
8 | exams and tests | How to diagnose Parasites – Cysticercosis ? | If you think that you may have cysticercosis, … | 8
9 | treatment | What are the treatments for Parasites – Cystic… | Some people with cysticercosis do not need to … | 9
MAX_ROWS = 15000
DOCUMENT="Answer"
TOPIC="qtype"

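ChromaDB requires every record to have a unique identifier. The statement data["id"] = data.index above creates a new column named id for that purpose.

# Because this is just a sample, we select a small portion of the dataset.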
subset_data = data.head(MAX_ROWS)

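4. Import and configure the vector database

I will use ChromaDB, the most popular open-source vector database.

First we import ChromaDB, and then the Settings class from the chromadb.config module. This class lets us change the settings of the ChromaDB system and customize its behavior.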

import chromadb
from chromadb.config import Settings

Now we only need to indicate the path where the vector database will be stored.

chroma_client = chromadb.PersistentClient(path="/path/to/persist/directory")

5. Populate and query the ChromaDB database

Data in ChromaDB is stored in collections. If the collection already exists, we need to delete it.

In the next lines we create the collection by calling the create_collection function of the chroma_client created above.

collection_name = "news_collection"
# Delete the collection if it already exists, wherever it appears in the list
if collection_name in [c.name for c in chroma_client.list_collections()]:
    chroma_client.delete_collection(name=collection_name)

collection = chroma_client.create_collection(name=collection_name)

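It is time to add the data to the collection. With the add function we need to provide, at a minimum, the documents, the metadatas, and the ids.

In documents we store the long text, which is a different column in each dataset.
In metadatas we can provide a list of topics.
In ids we need to provide a unique identifier for each row. It must be unique! I create the IDs from a range covering the selected rows (at most MAX_ROWS).

collection.add(
    documents=subset_data[DOCUMENT].tolist(),
    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
    # One id per row actually selected, even if the dataset has fewer than MAX_ROWS rows
    ids=[f"id{x}" for x in range(len(subset_data))],
)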

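Once the information is in the database, we can query it and ask for the data that matches our needs. The search is performed on the content of the documents, and it does not look for exact words or phrases. The results are based on the similarity between the search terms and the document content.

Metadata is not used in the search itself, but it can be used to filter or refine the results after the initial search.

Let's define a function to query the ChromaDB database.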

def query_database(query_text, n_results=10):
    results = collection.query(query_texts=query_text, n_results=n_results )
    return results
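The dictionary returned by the query holds one inner list per query text, which is why later code indexes it as answer['documents'][0][0]. The sketch below uses a hand-written sample with that shape; the ids, distances, and texts are made up for illustration, and `best_match` is a hypothetical helper, not a ChromaDB function.

```python
# Hypothetical sample shaped like the result of collection.query for
# ONE query text and n_results=2 (all values invented for illustration).
sample_results = {
    "ids": [["id12", "id7"]],
    "distances": [[0.41, 0.87]],
    "metadatas": [[{"qtype": "information"}, {"qtype": "symptoms"}]],
    "documents": [["First matching answer text", "Second matching answer text"]],
}


def best_match(results, query_index=0):
    """Return (id, distance, document) of the closest hit for one query.

    The outer index selects the query text, the inner index selects the
    hit; hits are ordered from closest to farthest.
    """
    return (
        results["ids"][query_index][0],
        results["distances"][query_index][0],
        results["documents"][query_index][0],
    )


doc_id, distance, document = best_match(sample_results)
```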

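6. Create the semantic cache system

To implement the cache system we use Faiss, a library that stores embeddings in memory. It is quite similar to what Chroma does, but without persistence.

For this purpose we create a class called semantic_cache that works with its own encoder and provides the functions the user needs to run queries.

In this class we first query Faiss (the cache). If the closest cached question is within the specified distance threshold, the result is returned from the cache. Otherwise, the result is fetched from the Chroma database.

The cache is stored in a .json file.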

!pip install -q faiss-cpu==1.8.0
import faiss
from sentence_transformers import SentenceTransformer
import time
import json

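This function initializes the semantic cache.

It uses a flat L2 index (IndexFlatL2), which may not be the fastest but is well suited to small datasets. Depending on the characteristics of the data to be cached and the expected dataset size, other indexes such as HNSW or IVF can be used.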
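To make "flat" concrete, here is a minimal pure-Python sketch of the exhaustive search such an index performs: the query is compared with every stored vector and the closest ones are returned. The function name `flat_l2_search` is hypothetical. It reports squared L2 distances, which is also what Faiss's IndexFlatL2 returns, so a threshold applied to those scores is a threshold on the squared distance.

```python
def flat_l2_search(stored_vectors, query, k=1):
    """Brute-force nearest-neighbor search.

    Returns (distances, indices) of the k closest stored vectors,
    using squared L2 distance and ordered from closest to farthest.
    """
    scored = sorted(
        (sum((q - v) ** 2 for q, v in zip(query, vector)), index)
        for index, vector in enumerate(stored_vectors)
    )
    top = scored[:k]
    return [d for d, _ in top], [i for _, i in top]


vectors = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
distances, indices = flat_l2_search(vectors, [0.9, 0.1], k=2)
# indices lists the stored rows from closest to farthest: row 1, then row 0
```

The cost grows linearly with the number of stored vectors, which is why approximate indexes such as HNSW or IVF become attractive for large caches.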

def init_cache():
    # all-mpnet-base-v2 produces 768-dimensional embeddings
    index = faiss.IndexFlatL2(768)
    if index.is_trained:
        print('Index trained')

    # Initialize Sentence Transformer model
    encoder = SentenceTransformer('all-mpnet-base-v2')

    return index, encoder

The retrieve_cache function loads the .json file from disk, in case the cache needs to be reused across sessions.

def retrieve_cache(json_file):
    try:
        with open(json_file, 'r') as file:
            cache = json.load(file)
    except FileNotFoundError:
        # No cache on disk yet: start with an empty one
        cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}

    return cache

The store_cache function saves the file containing the cache data to disk.

def store_cache(json_file, cache):
    with open(json_file, 'w') as file:
        json.dump(cache, file)
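The cache is a dictionary of four parallel lists, so position i in every list must describe the same entry, and that same position is the row number the vector index reports. The sketch below illustrates this invariant and a JSON round trip; `empty_cache`, `add_entry`, `is_consistent`, and `round_trip` are hypothetical helpers written for this example, and the sample values are made up.

```python
import json
import os
import tempfile


def empty_cache():
    return {"questions": [], "embeddings": [], "answers": [], "response_text": []}


def add_entry(cache, question, embedding, answer, response_text):
    """Append to all four lists together and return the new row number."""
    cache["questions"].append(question)
    cache["embeddings"].append(list(embedding))  # plain lists are JSON-serializable
    cache["answers"].append(answer)
    cache["response_text"].append(response_text)
    return len(cache["questions"]) - 1


def is_consistent(cache):
    """True when every list has the same length."""
    return len({len(values) for values in cache.values()}) == 1


def round_trip(cache):
    """Write the cache to a temporary JSON file and read it back."""
    with tempfile.TemporaryDirectory() as folder:
        path = os.path.join(folder, "cache.json")
        with open(path, "w") as file:
            json.dump(cache, file)
        with open(path) as file:
            return json.load(file)


cache_data = empty_cache()
row = add_entry(cache_data, "How do vaccines work?", [0.1, 0.2],
                {"documents": [["Vaccines contain..."]]}, "Vaccines contain...")
restored = round_trip(cache_data)
```

One consequence worth keeping in mind: when a cache is reloaded in a new session, the stored embeddings have to be added back to the in-memory index in the same order, otherwise the row numbers no longer match.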

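These functions are used by the semantic_cache class, which contains the search function and its initialization function.

Although the ask function has a lot of code, its purpose is quite simple. It looks in the cache for the question closest to the one the user has just asked.

It then checks whether that question is within the specified threshold. If it is, the response is returned directly from the cache; otherwise, the query_database function is called to retrieve the data from ChromaDB.

I used Euclidean distance rather than cosine, which is widely used for comparing vectors. The choice is based on the fact that Euclidean distance is the default metric used by Faiss. Cosine distance can also be computed, but doing so adds complexity and probably would not change the final result significantly.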
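The two metrics are closely related when the embeddings have unit length: in that case the squared Euclidean distance equals 2 - 2·cos(a, b), so a distance threshold maps directly to a minimum cosine similarity. The sketch below checks this numerically. It assumes unit-normalized vectors and a threshold expressed as a squared L2 distance; whether a particular encoder normalizes its output should be verified for the model in use. The helper names are hypothetical.

```python
import math


def normalize(vector):
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def threshold_to_cosine(squared_l2_threshold):
    """Minimum cosine similarity equivalent to a squared-L2 threshold,
    valid only for unit-length vectors."""
    return 1.0 - squared_l2_threshold / 2.0


a = normalize([1.0, 2.0, 3.0])
b = normalize([2.0, 1.0, 3.0])

left = squared_l2(a, b)                    # ||a - b||^2
right = 2.0 - 2.0 * cosine_similarity(a, b)
gap = abs(left - right)                    # only floating-point rounding remains

min_cosine = threshold_to_cosine(0.35)     # cosine floor implied by a 0.35 threshold
```

Under these assumptions, ranking by Euclidean distance and ranking by cosine similarity give the same nearest neighbor, which supports the choice of staying with the default metric.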

class semantic_cache:
    def __init__(self, json_file="cache_file.json", threshold=0.35):
        # Initialize the Faiss index (L2 distance) and the sentence encoder
        self.index, self.encoder = init_cache()

        # Distance threshold: a distance of 0 means identical sentences.
        # Only cached entries at or below this threshold are returned.
        self.euclidean_threshold = threshold

        self.json_file = json_file
        self.cache = retrieve_cache(self.json_file)

        # Re-index embeddings loaded from disk so that Faiss row ids
        # match the positions in the cache lists
        if self.cache['embeddings']:
            self.index.add(np.array(self.cache['embeddings'], dtype='float32'))

    def ask(self, question: str) -> str:
        # Return the context for a question from the cache,
        # or from ChromaDB when no cached question is close enough
        start_time = time.time()
        try:
            # First we obtain the embedding of the user question
            embedding = self.encoder.encode([question])

            # Search for the nearest neighbor in the index
            D, I = self.index.search(embedding, 1)

            # Faiss reports index -1 when the index is still empty
            if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
                row_id = int(I[0][0])
                response_text = self.cache['response_text'][row_id]

                print('Answer recovered from Cache. ')
                print(f'{D[0][0]:.3f} smaller than {self.euclidean_threshold}')
                print(f'Found cache in row: {row_id} with score {D[0][0]:.3f}')
                print(f'response_text: {response_text}')

                elapsed_time = time.time() - start_time
                print(f"Time taken: {elapsed_time:.3f} seconds")
                return response_text

            # Cache miss: no cached entry, or the nearest one is too far
            # away, so we ask ChromaDB.
            answer = query_database([question], 1)
            response_text = answer['documents'][0][0]

            # Append to every list so row ids stay aligned with the index
            self.cache['questions'].append(question)
            self.cache['embeddings'].append(embedding[0].tolist())
            self.cache['answers'].append(answer)
            self.cache['response_text'].append(response_text)

            print('Answer recovered from ChromaDB. ')
            print(f'response_text: {response_text}')

            self.index.add(embedding)
            store_cache(self.json_file, self.cache)
            elapsed_time = time.time() - start_time
            print(f"Time taken: {elapsed_time:.3f} seconds")

            return response_text
        except Exception as e:
            raise RuntimeError(f"Error during 'ask' method: {e}") from e

6.1 Testing the semantic_cache class

# Initialize the cache.
cache = semantic_cache('4cache_file.json')
Index trained
results = cache.ask("How work a vaccine?")
Answer recovered from ChromaDB. 
response_text: Summary : Shots may hurt a little, but the diseases they can prevent are a lot worse. Some are even life-threatening. Immunization shots, or vaccinations, are essential. They protect against things like measles, mumps, rubella, hepatitis B, polio, tetanus, diphtheria, and pertussis (whooping cough). Immunizations are important for adults as well as children.    Your immune system helps your body fight germs by producing substances to combat them. Once it does, the immune system "remembers" the germ and can fight it again. Vaccines contain germs that have been killed or weakened. When given to a healthy person, the vaccine triggers the immune system to respond and thus build immunity.     Before vaccines, people became immune only by actually getting a disease and surviving it. Immunizations are an easier and less risky way to become immune.     NIH: National Institute of Allergy and Infectious Diseases
Time taken: 0.655 seconds

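As expected, this response was obtained from ChromaDB. The class then stores it in the cache.

Now, if we send a second, completely different question, the response should also be retrieved from ChromaDB. The previously stored question is so different that its distance exceeds the specified threshold.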

results = cache.ask("Explain briefly what is a Periodic Paralyses")
Answer recovered from ChromaDB. 
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.
Time taken: 0.083 seconds

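Perfect, the semantic cache system behaves as expected.

Let's continue by testing it with a question very similar to the one we just asked.

In this case, the response should come directly from the cache, without accessing the ChromaDB database.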

results = cache.ask("Briefly explain me what is a periodic paralyses")
Answer recovered from Cache. 
0.018 smaller than 0.35
Found cache in row: 1 with score 0.018
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.
Time taken: 0.015 seconds

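The two questions are so similar that their Euclidean distance is very small, almost as if they were identical.

Now let's try another question, this time one that differs more noticeably, and observe how the system behaves.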

question_def = "Write in 20 words what is a periodic paralyses"
results = cache.ask(question_def)
Answer recovered from Cache. 
0.220 smaller than 0.35
Found cache in row: 1 with score 0.220
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.
Time taken: 0.017 seconds

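We can see that the Euclidean distance has increased, but it is still within the specified threshold. The response is therefore still returned directly from the cache.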

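7. Load the model and create the prompt

It is time to use the transformers library, Hugging Face's best-known library for working with language models.

We are importing:

AutoTokenizer: a utility class for tokenizing text inputs, compatible with a wide range of pretrained language models.

AutoModelForCausalLM: provides an interface to pretrained language models designed specifically for language generation tasks that use causal language modeling, such as the GPT models or the model used in this notebook, Gemma-2b-it.

The selected model is Gemma-2b-it.

Feel free to test different models; you need to look for NLP models trained for text generation.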

!pip install torch
import torch
# On Apple Silicon the device must be 'mps'
# device = torch.device('mps')  # to use with Apple Silicon
device = f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu'

from transformers import AutoTokenizer, AutoModelForCausalLM

#model_id = "databricks/dolly-v2-3b"
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load the model on the device selected above
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             device_map=device,
                                             torch_dtype=torch.bfloat16)

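8. Create the extended prompt

To create the prompt we use the result of querying the semantic_cache class together with the question asked by the user.

The prompt has two parts: the relevant context, which is the information recovered from the database, and the user's question.

We only need to put these two parts together to create the prompt and then send it to the model.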

prompt_template = f"Relevant context: {results}\n\n The user's question: {question_def}"
prompt_template
"Relevant context: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.\n                \nThe two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.\n\n The user's question: Write in 20 words what is a periodic paralyses"
input_ids = tokenizer(prompt_template, return_tensors="pt").to(device)

All that remains is to send the prompt to the model and wait for its response!

outputs = model.generate(**input_ids,
                         max_new_tokens=256)
print(tokenizer.decode(outputs[0]))
<bos>Relevant context: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.
                
The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in potassium levels in the blood. Attacks often begin in infancy or early childhood and are precipitated by rest after exercise or by fasting. Attacks are usually shorter, more frequent, and less severe than the hypokalemic form. Muscle spasms are common.

 The user's question: Write in 20 words what is a periodic paralyses?

Answer: A group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells.<eos>

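9. Conclusion

The performance improvement between querying ChromaDB and reading directly from the cache is roughly 50%. In larger projects, however, this difference grows and can lead to improvements of 90-95%.

We have very little data in Chroma and only a single instance of the cache class. Typically, the data behind a cache system is much larger and may involve more than a query to a vector database, coming from several different sources.

There are usually multiple instances of the cache class, often organized by user type, because questions tend to repeat more among users who share common characteristics.

In summary, we built a very simple RAG (Retrieval-Augmented Generation) system and enhanced it by adding a semantic cache layer between the user's question and the retrieval of the information needed to build the enriched prompt.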