進一寸有進一寸的歡喜,談談如何優化 Milvus 資料庫的向量查詢功能

Zilliz發表於2021-11-11

✏️ 編者按

每年暑期,Milvus 社群都會攜手中科院軟體所,在「開源之夏」活動中為高校學生們準備豐富的工程專案,並安排導師答疑解惑。張煜旻同學在「開源之夏」活動中表現優秀,他相信進一寸有進一寸的歡喜,嘗試在貢獻開源的過程中超越自我。

他的專案為 Milvus 資料庫的向量查詢操作提供精度控制,能讓開發者自定義返回精度,在減少記憶體消耗的同時,提高了返回結果的可讀性。

想要了解更多優質開源專案和專案經驗分享?請戳:有哪些值得參與的開源專案?

專案簡介

專案名稱:支援指定搜尋時返回的距離精度

學生簡介:張煜旻,中國科學院大學電子資訊軟體工程專業碩士在讀

專案導師:Zilliz 軟體工程師張財

導師評語: 張煜旻同學優化了 Milvus 資料庫的查詢功能,使其在搜尋時可以用指定精度去進行查詢,使搜尋過程更靈活,使用者可以根據自己的需求用不同的精度進行查詢,給使用者帶來了便利。

支援指定搜尋時返回的距離精度

任務簡介

在進行向量查詢時,搜尋請求返回 id 和 distance 欄位,其中的 distance 欄位型別是浮點數。Milvus 資料庫所計算的距離是一個 32 位浮點數,但是 Python SDK 返回並以 64 位浮點顯示它,導致某些精度無效。本專案的貢獻是,支援指定搜尋時返回的距離精度,解決了在 Python 端顯示時部分精度無效的情況,並減少部分記憶體開銷。

專案目標

  • 解決計算結果和顯示精度不匹配的問題
  • 支援搜尋時返回指定的距離精度
  • 補充相關文件

專案步驟

  • 前期調研,理解 Milvus 整體框架
  • 明確各模組之間的呼叫關係
  • 設計解決方案和確認結果

專案綜述

什麼是 Milvus 資料庫?

Milvus 是一款開源向量資料庫,賦能 AI 應用和向量相似度搜尋。在系統設計上, Milvus 資料庫的前端有方便使用者使用的 Python SDK(Client);在 Milvus 資料庫的後端,整個系統分為了接入層(Access Layer)、協調服務(Coordinator Server)、執行節點(Worker Node)和儲存服務(Storge)四個層面:

(1)接入層(Access Layer):系統的門面,包含了一組對等的 Proxy 節點。接入層是暴露給使用者的統一 endpoint,負責轉發請求並收集執行結果。

(2)協調服務(Coordinator Service):系統的大腦,負責分配任務給執行節點。共有四類協調者角色:root 協調者、data 協調者、query 協調者和 index 協調者。

(3)執行節點(Worker Node):系統的四肢,執行節點只負責被動執行協調服務發起的讀寫請求。目前有三類執行節點:data 節點、query 節點和 index 節點。

(4)儲存服務(Storage):系統的骨骼,是所有其他功能實現的基礎。Milvus  資料庫依賴三類儲存:後設資料儲存、訊息儲存(log broker)和物件儲存。從語言角度來看,則可以看作三個語言層,分別是 Python 構成的 SDK 層、Go 構成的中間層和 C++ 構成的核心計算層。

Milvus 資料庫的架構圖

向量查詢 Search 時,到底發生了什麼?

在 Python SDK 端,當使用者發起一個 Search API 呼叫時,這個呼叫會被封裝成 gRPC 請求併傳送給 Milvus 後端,同時 SDK 開始等待。而在後端,Proxy 節點首先接受了從 Python SDK 傳送過來的請求,然後會對接受的請求進行處理,最後將其封裝成 message,經由 Producer 傳送到消費佇列中。當訊息被髮送到消費佇列後,Coordinator 將會對其進行協調,將資訊傳送到合適的 query node 中進行消費。而當 query node 接收到訊息後,則會對訊息進行進一步的處理,最後將資訊傳遞給由 C++ 構成的計算層。在計算層,則會根據不同的情形,呼叫不同的計算函式對向量間的距離進行計算。當計算完成後,結果則會依次向上傳遞,直到到達 SDK 端。

解決方案設計

通過前文簡單介紹,我們對向量查詢的過程有了一個大致的概念。同時,我們也可以清楚地認識到,為了完成查詢目標,我們需要對 Python 構成的 SDK 層、Go 構成的中間層和 C++ 構成的計算層都進行修改,修改方案如下:

1. 在 Python 層中的修改步驟:

為向量查詢 Search 請求新增一個 round_decimal 引數,從而確定返回的精度資訊。同時,需要對引數進行一些合法性檢查和異常處理,從而構建 gRPC 的請求:

round_decimal = param_copy("round_decimal", 3)
if not isinstance(round_decimal, (int, str))
  raise ParamError("round_decimal must be int or str")
try:
  round_decimal = int(round_decimal)
except Exception:
  raise ParamError("round_decimal is not illegal")

if round_decimal < 0 or round_decimal > 6:
  raise ParamError("round_decimal must be greater than zero and less than seven")
if not instance(params, dict):
  raise ParamError("Search params must be a dict")
search_params = {"anns_field": anns_field, "topk": limit, "metric_type": metric_type, "params": params, "round_decimal": round_decimal}

2. 在 Go 層中的修改步驟:

在 task.go 檔案中新增 RoundDecimalKey 這個常量,保持風格統一併方便後續調取:

const (
 InsertTaskName                  = "InsertTask"
 CreateCollectionTaskName        = "CreateCollectionTask"
 DropCollectionTaskName          = "DropCollectionTask"
 SearchTaskName                  = "SearchTask"
 RetrieveTaskName                = "RetrieveTask"
 QueryTaskName                   = "QueryTask"
 AnnsFieldKey                    = "anns_field"
 TopKKey                         = "topk"
 MetricTypeKey                   = "metric_type"
 SearchParamsKey                 = "params"
 RoundDecimalKey                 = "round_decimal"
 HasCollectionTaskName           = "HasCollectionTask"
 DescribeCollectionTaskName      = "DescribeCollectionTask"

接著,修改 PreExecute 函式,獲取 round_decimal 的值,構建 queryInfo 變數,並新增異常處理:

searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, st.query.SearchParams)
        if err != nil {
            return errors.New(SearchParamsKey + " not found in search_params")
        }
        roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, st.query.SearchParams)
        if err != nil {
            return errors.New(RoundDecimalKey + "not found in search_params")
        }
        roundDeciaml, err := strconv.Atoi(roundDecimalStr)
        if err != nil {
            return errors.New(RoundDecimalKey + " " + roundDecimalStr + " is not invalid")
        }

        queryInfo := &planpb.QueryInfo{
            Topk:         int64(topK),
            MetricType:   metricType,
            SearchParams: searchParams,
            RoundDecimal: int64(roundDeciaml),
        }

同時,修改 query 的 proto 檔案,為 QueryInfo 新增 round_decimal 變數:

message QueryInfo {
int64 topk = 1;
string metric_type = 3;
string search_params = 4;
int64 round_decimal = 5;
}

3. 在 C++ 層中的修改步驟:

在 SearchInfo 結構體中新增新的變數 round\_decimal\_ ,從而接受 Go 層傳來的 round_decimal 值:

struct SearchInfo {
   int64_t topk_;
   int64_t round_decimal_;
   FieldOffset field_offset_;
   MetricType metric_type_;
   nlohmann::json search_params_;
};

在 ParseVecNode 和 PlanNodeFromProto 函式中,SearchInfo 結構體需要接受 Go 層中 round_decimal 值:

std::unique_ptr<VectorPlanNode>
Parser::ParseVecNode(const Json& out_body) {
    Assert(out_body.is_object());
    Assert(out_body.size() == 1);
    auto iter = out_body.begin();
    auto field_name = FieldName(iter.key());

    auto& vec_info = iter.value();
    Assert(vec_info.is_object());
    auto topk = vec_info["topk"];
    AssertInfo(topk > 0, "topk must greater than 0");
    AssertInfo(topk < 16384, "topk is too large");

    auto field_offset = schema.get_offset(field_name);

    auto vec_node = [&]() -> std::unique_ptr<VectorPlanNode> {
        auto& field_meta = schema.operator[](field_name);
        auto data_type = field_meta.get_data_type();
        if (data_type == DataType::VECTOR_FLOAT) {
            return std::make_unique<FloatVectorANNS>();
        } else {
            return std::make_unique<BinaryVectorANNS>();
        }
    }();
    vec_node->search_info_.topk_ = topk;
    vec_node->search_info_.metric_type_ = GetMetricType(vec_info.at("metric_type"));
    vec_node->search_info_.search_params_ = vec_info.at("params");
    vec_node->search_info_.field_offset_ = field_offset;
    vec_node->search_info_.round_decimal_ = vec_info.at("round_decimal");
    vec_node->placeholder_tag_ = vec_info.at("query");
    auto tag = vec_node->placeholder_tag_;
    AssertInfo(!tag2field_.count(tag), "duplicated placeholder tag");
    tag2field_.emplace(tag, field_offset);
    return vec_node;
}
std::unique_ptr<VectorPlanNode>
ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
    // TODO: add more buffs
    Assert(plan_node_proto.has_vector_anns());
    auto& anns_proto = plan_node_proto.vector_anns();
    auto expr_opt = [&]() -> std::optional<ExprPtr> {
        if (!anns_proto.has_predicates()) {
            return std::nullopt;
        } else {
            return ParseExpr(anns_proto.predicates());
        }
    }();

    auto& query_info_proto = anns_proto.query_info();

    SearchInfo search_info;
    auto field_id = FieldId(anns_proto.field_id());
    auto field_offset = schema.get_offset(field_id);
    search_info.field_offset_ = field_offset;

    search_info.metric_type_ = GetMetricType(query_info_proto.metric_type());
    search_info.topk_ = query_info_proto.topk();
    search_info.round_decimal_ = query_info_proto.round_decimal();
    search_info.search_params_ = json::parse(query_info_proto.search_params());

    auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
        if (anns_proto.is_binary()) {
            return std::make_unique<BinaryVectorANNS>();
        } else {
            return std::make_unique<FloatVectorANNS>();
        }
    }();
    plan_node->placeholder_tag_ = anns_proto.placeholder_tag();
    plan_node->predicate_ = std::move(expr_opt);
    plan_node->search_info_ = std::move(search_info);
    return plan_node;
}

在 SubSearchResult 類新增新的成員變數 round_decimal,同時修改每一處的 SubSearchResult 變數宣告:

class SubSearchResult {
public:
   SubSearchResult(int64_t num_queries, int64_t topk, MetricType metric_type)
       : metric_type_(metric_type),
         num_queries_(num_queries),
         topk_(topk),
         labels_(num_queries * topk, -1),
         values_(num_queries * topk, init_value(metric_type)) {
   }

在 SubSearchResult 類新增一個新的成員函式,以便最後對每一個結果進行四捨五入精度控制:

void
SubSearchResult::round_values() {
    if (round_decimal_ == -1)
        return;
    const float multiplier = pow(10.0, round_decimal_);
    for (auto it = this->values_.begin(); it != this->values_.end(); it++) {
        *it = round(*it * multiplier) / multiplier;
    }
}

為 SearchDataset 結構體新增新的變數 round_decimal,同時修改每一處的 SearchDataset 變數宣告:

struct SearchDataset {
    MetricType metric_type;
    int64_t num_queries;
    int64_t topk;
    int64_t round_decimal;
    int64_t dim;
    const void* query_data;
};

修改 C++ 層中各個距離計算函式(FloatSearch、BinarySearchBruteForceFast 等等),使其接受 round_decomal 值:

Status
FloatSearch(const segcore::SegmentGrowingImpl& segment,
            const query::SearchInfo& info,
            const float* query_data,
            int64_t num_queries,
            int64_t ins_barrier,
            const BitsetView& bitset,
            SearchResult& results) {
    auto& schema = segment.get_schema();
    auto& indexing_record = segment.get_indexing_record();
    auto& record = segment.get_insert_record();
    // step 1: binary search to find the barrier of the snapshot
    // auto del_barrier = get_barrier(deleted_record_, timestamp);

#if 0
    auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier);
    Assert(bitmap_holder);
    auto bitmap = bitmap_holder->bitmap_ptr;
#endif

    // step 2.1: get meta
    // step 2.2: get which vector field to search
    auto vecfield_offset = info.field_offset_;
    auto& field = schema[vecfield_offset];

    AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT");
    auto dim = field.get_dim();
    auto topk = info.topk_;
    auto total_count = topk * num_queries;
    auto metric_type = info.metric_type_;
    auto round_decimal = info.round_decimal_;
    // step 3: small indexing search
    // std::vector<int64_t> final_uids(total_count, -1);
    // std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
    SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal);
    dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};
    auto vec_ptr = record.get_field_data<FloatVector>(vecfield_offset);

    int current_chunk_id = 0;
SubSearchResult
BinarySearchBruteForceFast(MetricType metric_type,
                           int64_t dim,
                           const uint8_t* binary_chunk,
                           int64_t size_per_chunk,
                           int64_t topk,
                           int64_t num_queries,
                           int64_t round_decimal,
                           const uint8_t* query_data,
                           const faiss::BitsetView& bitset) {
    SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
    float* result_distances = sub_result.get_values();
    idx_t* result_labels = sub_result.get_labels();

    int64_t code_size = dim / 8;
    const idx_t block_size = size_per_chunk;

    raw_search(metric_type, binary_chunk, size_per_chunk, code_size, num_queries, query_data, topk, result_distances,
               result_labels, bitset);
    sub_result.round_values();
    return sub_result;
}

結果確認

1. 對 Milvus 資料庫進行重新編譯:

2. 啟動環境容器:

3. 啟動 Milvus 資料庫:

4.構建向量查詢請求:

5. 確認結果,預設保留 3 位小數,0 捨去:

總結和感想

參加這次的夏季開源活動,對 我來說是非常寶貴的經歷。在這次活動中,我第一次嘗試閱讀開源專案程式碼,第一次嘗試接觸多語言構成的專案,第一次接觸到 Make、gRPc、pytest 等等。在編寫程式碼和測試程式碼階段,我也遇到來許多意想不到的問題,例如,「奇奇怪怪」的依賴問題、由於 Conda 環境導致的編譯失敗問題、測試無法通過等等。面對這些問題,我 漸漸學會耐心細心地檢視報錯日誌,積極思考、檢查程式碼並進行測試,一步一步縮小錯誤範圍,定位錯誤程式碼並嘗試各種解決方案。

通過這次的活動 ,我吸取 了 許 多 經驗和教訓,同時也十分感謝張財導師,感謝他在我開發過程中 耐心地幫我答疑解惑、指導方向 !同時, 希望大家能多多關注 Milvus 社群 , 相信 一定能夠有所收穫!

最後,歡迎大家多多與我交流(? deepmin@mail.deepexplore.top ),我主要的研究方向是自然語言處理,平時喜歡看科幻小說、動畫和折騰伺服器個人網站,每日閒逛 Stack Overflow 和GitHub。我相信進一寸有進一寸的歡喜,希望能和你一起共同進步。

相關文章