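Author: Jia Shiwen (賈世聞), JD Technology (京東科技)
RAG (Retrieval-Augmented Generation) plays a crucial role in the AI ecosystem, especially in improving the accuracy and broadening the applicability of large language models (LLMs). RAG combines retrieval with LLM prompting: it retrieves relevant information from various data sources, combines it with the user's question, and generates an accurate and informative answer. This mechanism is particularly useful when the underlying information keeps changing, because the parametric knowledge an LLM relies on is inherently static.
The strength of RAG is that it can draw on an external knowledge base and reference a large body of information, which yields deeper, more accurate and more useful answers and makes the generated text more reliable. In addition, because the retrieval store can be updated, knowledge can be refreshed immediately without retraining the model, which is an advantage in applications with strict timeliness requirements.
Building a RAG system is no longer particularly hard: with a mature framework such as LangChain, a demo takes roughly a hundred lines of code. Can we build a simple RAG with today's Rust ecosystem instead? Let's try. In this installment we walk through how to build a RAG in Rust.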
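As a minimal illustration of the "combine retrieved information with the user's question" step, the sketch below assembles an augmented prompt in Rust. The function name build_augmented_prompt and the prompt template are hypothetical; a real system would tune the template for its LLM.

```rust
/// Builds an augmented prompt by placing retrieved passages ahead of the user's question.
/// The template wording is an illustrative assumption, not a format required by any LLM.
fn build_augmented_prompt(question: &str, passages: &[&str]) -> String {
    let mut prompt = String::from("Answer the question using only the context below.\n\n");
    for (i, passage) in passages.iter().enumerate() {
        // Number each retrieved passage so the model can tell them apart.
        prompt.push_str(&format!("[{}] {}\n", i + 1, passage));
    }
    // The user's question goes last, after all retrieved context.
    prompt.push_str(&format!("\nQuestion: {}\nAnswer:", question));
    prompt
}

fn main() {
    let passages = ["The Server Migration Service itself is free."];
    let prompt = build_augmented_prompt("Is server migration free?", &passages);
    println!("{}", prompt);
}
```

The resulting string is what would be sent to the LLM in place of the bare question.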
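Building the knowledge base
The knowledge base consists mainly of an embedding model plus a vector store. So that every component of the system is written in Rust, we chose Qdrant, a vector database implemented purely in Rust, as the vector store.
The most important step in building the knowledge base is embedding.
The process is:
- Load the model
- Tokenize the text
- Obtain the text's embedding from the model
The details and code for each step follow.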
Loading the model
The following code loads the model and the tokenizer:
async fn build_model_and_tokenizer(model_config: &ConfigModel) -> Result<(BertModel, Tokenizer)> {
let device = Device::new_cuda(0)?;
let repo = Repo::with_revision(
model_config.model_id.clone(),
RepoType::Model,
model_config.revision.clone(),
);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = ApiBuilder::new()
.build()?;
let api = api.repo(repo);
let config = api.get("config.json").await?;
let tokenizer = api.get("tokenizer.json").await?;
let weights = if model_config.use_pth {
api.get("pytorch_model.bin").await?
} else {
api.get("model.safetensors").await?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if model_config.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
if model_config.approximate_gelu {
config.hidden_act = HiddenAct::GeluApproximate;
}
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
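The model and tokenizer are used frequently throughout the system, so to avoid loading them repeatedly we keep them in a global static initialized through OnceCell:
pub static GLOBAL_EMBEDDING_MODEL: OnceCell<Arc<(BertModel, Tokenizer)>> = OnceCell::const_new();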
pub async fn init_model_and_tokenizer() -> Arc<(BertModel, Tokenizer)> {
let config = get_config().unwrap();
let (m, t) = build_model_and_tokenizer(&config.model).await.unwrap();
Arc::new((m, t))
}
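The load-once behavior can be seen without candle or tokio. The sketch below uses `std::sync::OnceLock`, the synchronous counterpart of the tokio `OnceCell` used above, with a `String` standing in for the `(BertModel, Tokenizer)` pair. The names `GLOBAL_RESOURCE`, `INIT_CALLS` and `load_expensive_resource` are hypothetical.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};

// Counts how many times the expensive initializer actually runs.
static INIT_CALLS: AtomicUsize = AtomicUsize::new(0);

// Stand-in for the (model, tokenizer) pair; a real program would store BertModel and Tokenizer.
static GLOBAL_RESOURCE: OnceLock<Arc<String>> = OnceLock::new();

fn load_expensive_resource() -> Arc<String> {
    INIT_CALLS.fetch_add(1, Ordering::SeqCst);
    Arc::new("model + tokenizer".to_string())
}

fn resource() -> Arc<String> {
    // get_or_init runs the initializer only on the first call; later calls reuse the cached value.
    GLOBAL_RESOURCE.get_or_init(load_expensive_resource).clone()
}

fn main() {
    let a = resource();
    let b = resource();
    // Both handles point at the same allocation, and the initializer ran once.
    println!("same instance: {}", Arc::ptr_eq(&a, &b));
    println!("initializer ran {} time(s)", INIT_CALLS.load(Ordering::SeqCst));
}
```

The async OnceCell in the article works the same way, except that its initializer is a future that is awaited.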
Load the model at system startup:
GLOBAL_RUNTIME.block_on(async {
log::info!("global runtime start!");
// Load the model
GLOBAL_EMBEDDING_MODEL
.get_or_init(init_model_and_tokenizer)
.await;
});
The embedding itself is performed mainly by the following function:
pub async fn embedding_setence(content: &str) -> Result<Vec<Vec<f32>>> {
let m_t = GLOBAL_EMBEDDING_MODEL.get().unwrap();
let tokens = m_t
.1
.encode(content, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &m_t.0.device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?;
let sequence_output = m_t.0.forward(&token_ids, &token_type_ids)?;
let (_n_sentence, n_tokens, _hidden_size) = sequence_output.dims3()?;
let embeddings = (sequence_output.sum(1)? / (n_tokens as f64))?;
let embeddings = normalize_l2(&embeddings)?;
let encodings = embeddings.to_vec2::<f32>()?;
Ok(encodings)
}
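The function encodes the input text with the tokenizer, runs the token IDs through the model to get a three-dimensional tensor of shape (n_sentence, n_tokens, hidden_size), averages it over the token dimension (mean pooling) to get one vector per sentence, and finally L2-normalizes that vector.
Ingesting the data
Building the knowledge base means vectorizing the text to be retrieved and storing the vectors in the vector database.
Here we use the JD Cloud documentation as the source text, processed into the format below. The processing itself is not covered here.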
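To make the pooling and normalization arithmetic concrete, here is a plain-Rust sketch that works on `Vec<f32>` instead of candle tensors, for a single sentence. It illustrates the math under that simplification and is not the candle code path; `mean_pool` and this `normalize_l2` variant are hypothetical helpers. Averaging over every token is safe here because one sentence is encoded at a time, so there is no padding.

```rust
/// Mean-pools token vectors (n_tokens x hidden) into one sentence vector,
/// mirroring `sequence_output.sum(1)? / n_tokens` for a single sentence.
/// Assumes at least one token.
fn mean_pool(token_vectors: &[Vec<f32>]) -> Vec<f32> {
    let hidden = token_vectors[0].len();
    let mut pooled = vec![0.0f32; hidden];
    for v in token_vectors {
        for (p, x) in pooled.iter_mut().zip(v) {
            *p += *x;
        }
    }
    let n = token_vectors.len() as f32;
    pooled.iter_mut().for_each(|p| *p /= n);
    pooled
}

/// Divides the vector by its Euclidean (L2) norm; a zero vector would produce NaN.
fn normalize_l2(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    v.iter().map(|x| x / norm).collect()
}

fn main() {
    // Two tokens with hidden size 2.
    let tokens: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0, 6.0]];
    let pooled = mean_pool(&tokens); // [2.0, 4.0]
    let unit = normalize_l2(&pooled);
    println!("{:?} -> {:?}", pooled, unit);
}
```

After normalization every embedding has length 1, which is what lets a dot product stand in for cosine similarity at search time.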
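Conceptually, a vector store keeps (id, vector, payload) points and answers a query vector with the most similar points. The ingestion code in this article creates its Qdrant collection with `Distance::Dot`; because the embeddings are L2-normalized, the dot product equals cosine similarity. The brute-force `InMemoryStore` below is a hypothetical stand-in that shows the idea; it is not Qdrant's API and does none of Qdrant's indexing.

```rust
/// A hypothetical brute-force vector store; not Qdrant's API, no indexing.
struct InMemoryStore {
    points: Vec<(String, Vec<f32>, String)>, // (id, vector, payload text)
}

impl InMemoryStore {
    fn new() -> Self {
        Self { points: Vec::new() }
    }

    /// Replaces the point with the same id if it exists, otherwise appends a new one.
    fn upsert(&mut self, id: &str, vector: Vec<f32>, payload: &str) {
        let point = (id.to_string(), vector, payload.to_string());
        match self.points.iter().position(|p| p.0 == id) {
            Some(i) => self.points[i] = point,
            None => self.points.push(point),
        }
    }

    /// Returns up to `k` (id, score) pairs ranked by dot product, highest first.
    fn search(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
        let mut scored: Vec<(String, f32)> = self
            .points
            .iter()
            .map(|(id, v, _)| (id.clone(), v.iter().zip(query).map(|(a, b)| a * b).sum::<f32>()))
            .collect();
        // total_cmp gives a total order even if a score is NaN.
        scored.sort_by(|a, b| b.1.total_cmp(&a.1));
        scored.truncate(k);
        scored
    }
}

fn main() {
    let mut store = InMemoryStore::new();
    store.upsert("billing", vec![1.0, 0.0], "Service billing");
    store.upsert("network", vec![0.0, 1.0], "Network setup");
    // "billing" scores 0.8 and "network" 0.6, so "billing" ranks first.
    let hits = store.search(&[0.8, 0.6], 1);
    println!("{:?}", hits);
}
```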
{
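"content": "# Service Billing\n\nThe Server Migration Service itself is free. However, when the migration target is a cloud host image, the migration relies on transit resources that the system creates automatically; some of these are paid resources and incur the corresponding charges.\n\nThe paid transit resources involved in a migration and how they are billed (per migration task):\n\n| | Cloud Host | Cloud Disk | Elastic Public IP |\n| --- | --- | --- | ------ |\n| Billing type | By configuration | By configuration | By usage |\n| Specification | 2C4G (c.n2.large, c.n3.large or c.n1.large) | System disk: 40G general-purpose SSD; data disks: general-purpose SSD, number and capacity depend on the source server's system and data disks | 30Mbps |\n| Estimated cost | Hourly price of the cloud host spec \\* migration duration | Hourly price of the cloud disk spec \\* migration duration | Hourly holding fee of the elastic public IP \\* migration duration. Only inbound traffic on the elastic public IP is used, so only the IP holding fee applies and no traffic fee is charged |\n\n> Note:\n>\n> * Migration duration depends on the amount of data on the source server and its outbound public bandwidth. With a smooth public network connection and at least 22.5Mbps of outbound public bandwidth on the source server (host migration uses single-threaded transfer, and a JD Cloud host reaches about 75% of its bandwidth cap on a single stream), migrating a disk holding 5GB of actual data takes about 30 minutes.\n> * By default, the outbound rules of the security group bound to the transit instance deny all traffic, so no outbound public traffic charges are incurred. However, this configuration also blocks some cloud host monitoring metrics from being reported. To collect all monitoring data of the transit instance, adjust the security group rules to allow outbound traffic on port 443.",
"title": "Service Billing Description",
"product": "Cloud Virtual Machine (CVM)",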
"url": "https://docs.jdcloud.com/cn/virtual-machines/server-migration-service/billing"
}
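The ingestion code that follows embeds files one at a time and sends points to Qdrant in batches of 100, flushing any remainder at the end. When all items are already in memory, the same batching can be written with `slice::chunks`, as in this sketch; `upsert_in_batches` and its callback are hypothetical stand-ins for the `upsert_points` call.

```rust
/// Splits `items` into batches of at most `batch_size` (must be > 0) and passes each
/// batch to `upsert_batch`, a hypothetical stand-in for `client.upsert_points(...)`.
/// Returns the number of batches sent.
fn upsert_in_batches<T, F>(items: &[T], batch_size: usize, mut upsert_batch: F) -> usize
where
    F: FnMut(&[T]),
{
    let mut batches = 0;
    for batch in items.chunks(batch_size) {
        upsert_batch(batch);
        batches += 1;
    }
    batches
}

fn main() {
    let ids: Vec<u32> = (0..250).collect();
    let mut sizes = Vec::new();
    let n = upsert_in_batches(&ids, 100, |batch| sizes.push(batch.len()));
    println!("{} batches with sizes {:?}", n, sizes); // 3 batches with sizes [100, 100, 50]
}
```

Flushing inside the loop, as the article's code does, avoids holding every point in memory at once; chunks is simpler when the full list already exists.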
The complete ingestion code is as follows:
use anyhow::Error as E;
use anyhow::Result;
use candle_core::Device;
use candle_core::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use hf_hub::{api::tokio::Api, Repo, RepoType};
use qdrant_client::qdrant::CollectionExistsRequest;
use qdrant_client::qdrant::CreateCollectionBuilder;
use qdrant_client::qdrant::DeleteCollection;
use qdrant_client::qdrant::Distance;
use qdrant_client::qdrant::UpsertPointsBuilder;
use qdrant_client::qdrant::VectorParamsBuilder;
use qdrant_client::Payload;
use qdrant_client::{
qdrant::{
CollectionOperationResponse, CreateCollection, PointStruct, PointsOperationResponse,
UpsertPoints,
},
Qdrant,
};
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use std::fs;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::OnceCell;
use uuid::Uuid;
use walkdir::WalkDir;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Doc {
pub content: String,
pub title: String,
pub product: String,
pub url: String,
}
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
pub struct ModelConfig {
#[serde(default = "ModelConfig::model_id_default")]
pub model_id: String,
#[serde(default = "ModelConfig::revision_default")]
pub revision: String,
#[serde(default = "ModelConfig::use_pth_default")]
pub use_pth: bool,
#[serde(default = "ModelConfig::approximate_gelu_default")]
pub approximate_gelu: bool,
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
model_id: Self::model_id_default(),
revision: Self::revision_default(),
use_pth: Self::use_pth_default(),
approximate_gelu: Self::approximate_gelu_default(),
}
}
}
impl ModelConfig {
fn model_id_default() -> String {
"moka-ai/m3e-large".to_string()
}
fn revision_default() -> String {
"main".to_string()
}
fn use_pth_default() -> bool {
true
}
fn approximate_gelu_default() -> bool {
false
}
}
pub static GLOBAL_MODEL: OnceCell<Arc<BertModel>> = OnceCell::const_new();
pub static GLOBAL_TOKEN: OnceCell<Arc<Tokenizer>> = OnceCell::const_new();
pub async fn init_model() -> Arc<BertModel> {
let config = ModelConfig::default();
let (m, _) = build_model_and_tokenizer(&config).await.unwrap();
Arc::new(m)
}
pub async fn init_tokenizer() -> Arc<Tokenizer> {
// Note: this loads the model weights again; only the tokenizer is kept
let config = ModelConfig::default();
let (_, t) = build_model_and_tokenizer(&config).await.unwrap();
Arc::new(t)
}
async fn build_model_and_tokenizer(model_config: &ModelConfig) -> Result<(BertModel, Tokenizer)> {
let device = Device::new_cuda(0)?;
let repo = Repo::with_revision(
model_config.model_id.clone(),
RepoType::Model,
model_config.revision.clone(),
);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json").await?;
let tokenizer = api.get("tokenizer.json").await?;
let weights = if model_config.use_pth {
api.get("pytorch_model.bin").await?
} else {
api.get("model.safetensors").await?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if model_config.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
if model_config.approximate_gelu {
config.hidden_act = HiddenAct::GeluApproximate;
}
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
pub async fn embedding_setence(content: &str) -> Result<Vec<Vec<f32>>> {
let m = GLOBAL_MODEL.get().unwrap();
let t = GLOBAL_TOKEN.get().unwrap();
let tokens = t.encode(content, true).map_err(E::msg)?.get_ids().to_vec();
let token_ids = Tensor::new(&tokens[..], &m.device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?;
let sequence_output = m.forward(&token_ids, &token_type_ids)?;
let (_n_sentence, n_tokens, _hidden_size) = sequence_output.dims3()?;
// Mean pooling over the token dimension
let embeddings = (sequence_output.sum(1)? / (n_tokens as f64))?;
let embeddings = normalize_l2(&embeddings)?;
let encodings = embeddings.to_vec2::<f32>()?;
Ok(encodings)
}
// Divide each row by its L2 norm
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
pub struct QdrantClient {
client: Qdrant,
}
impl QdrantClient {
pub async fn create_collection(
&self,
request: impl Into<CreateCollection>,
) -> Result<CollectionOperationResponse> {
let resp = self.client.create_collection(request).await?;
Ok(resp)
}
pub async fn delete_collection(
&self,
request: impl Into<DeleteCollection>,
) -> Result<CollectionOperationResponse> {
let resp = self.client.delete_collection(request).await?;
Ok(resp)
}
pub async fn collection_exists(
&self,
request: impl Into<CollectionExistsRequest>,
) -> Result<bool> {
let resp = self.client.collection_exists(request).await?;
Ok(resp)
}
pub async fn load_dir(&self, path: &str, collection_name: &str) {
let mut points = vec![];
for entry in WalkDir::new(path)
.into_iter()
.filter_map(Result::ok)
.filter(|e| !e.file_type().is_dir() && e.file_name().to_str().is_some())
{
if let Some(p) = entry.path().to_str() {
let id = Uuid::new_v4();
let content = match fs::read_to_string(p) {
Ok(c) => c,
Err(_) => continue,
};
let doc = match from_str::<Doc>(content.as_str()) {
Ok(d) => d,
Err(_) => continue,
};
let mut payload = Payload::new();
payload.insert("content", doc.content);
payload.insert("title", doc.title);
payload.insert("product", doc.product);
payload.insert("url", doc.url);
// Note: this embeds the raw file text (the whole JSON record), not only doc.content
let vectors = embedding_setence(content.as_str()).await.unwrap();
let ps = PointStruct::new(id.to_string(), vectors[0].clone(), payload);
points.push(ps);
if points.len() >= 100 {
// Flush a full batch; std::mem::take leaves points empty for the next batch
let batch = std::mem::take(&mut points);
self.client
.upsert_points(UpsertPointsBuilder::new(collection_name, batch).wait(true))
.await
.unwrap();
println!("batch finish");
}
}
}
if !points.is_empty() {
self.client
.upsert_points(UpsertPointsBuilder::new(collection_name, points).wait(true))
.await
.unwrap();
}
}
}
#[tokio::main]
async fn main() {
// Load the model and tokenizer
GLOBAL_MODEL.get_or_init(init_model).await;
GLOBAL_TOKEN.get_or_init(init_tokenizer).await;
let collection_name = "default_collection";
// The Rust client uses Qdrant's GRPC interface
let qdrant = Qdrant::from_url("http://localhost:6334").build().unwrap();
let qdrant_client = QdrantClient { client: qdrant };
if !qdrant_client
.collection_exists(collection_name)
.await
.unwrap()
{
qdrant_client
.create_collection(
CreateCollectionBuilder::new(collection_name)
.vectors_config(VectorParamsBuilder::new(1024, Distance::Dot)),
)
.await