並行化最佳化KD樹演算法:使用C#實現高效的最近鄰搜尋

程序设计实验室發表於2024-03-10

本文資訊

中文名:《並行化最佳化KD樹演算法:使用C#實現高效的最近鄰搜尋》

英文名:"Parallelized Optimization of KD-Tree Algorithm: Implementing Efficient Nearest Neighbor Search in C#"

摘要

本文介紹瞭如何使用平行計算技術最佳化 KD 樹演算法,並使用 C# 程式語言實現了高效的最近鄰搜尋。首先,我們簡要介紹了 KD 樹的原理和構建過程,然後詳細討論瞭如何利用平行計算庫在多個 CPU 核心上並行構建 KD 樹,從而加速搜尋過程。透過實驗驗證,我們證明了並行化最佳化能夠顯著提高 KD 樹的構建速度和搜尋效率,為大規模資料集下的最近鄰搜尋問題提供了一種高效的解決方案。

Summary

This article presents a parallelized optimization approach for KD-tree algorithm and demonstrates efficient nearest neighbor search implementation in C#. We first introduce the principles and construction process of KD trees, and then discuss in detail how to leverage parallel computing techniques to build KD trees concurrently on multiple CPU cores, thus accelerating the search process. Through experimental validation, we prove that parallelized optimization significantly improves the construction speed and search efficiency of KD trees, providing an efficient solution for nearest neighbor search problems on large-scale datasets.

版本資訊

本文涉及到的 C# 程式碼使用 .Net 8.0 以及 C# 12 版本編寫。

前言

思考以下場景:有 1000 個 A 型別的地點(包含地址和GPS座標),以及 50000 個 B 型別的地點,需要找出距離每個 A 型別地點最近的 B 型別地點。

我第一時間想到的是之前做推薦系統用過的 KNN 演算法,不過 KNN 的實現計算量太大了,當資料量多的時候,需要耗費很多時間,因此針對這個場景,我採用了 KD 樹演算法。

關於 KD 樹

KD 樹(K-Dimensional Tree)是一種用於多維空間中的資料結構,它是二叉樹的一種變種,用於高效地組織和搜尋多維資料。同時也是是 KNN 的一個高效演算法,KD 樹的主要優點是可以在高維空間中進行快速的最近鄰搜尋和範圍搜尋。

關於 KD 樹的詳細原理不在本文的討論範圍內,本文只做簡單介紹。本文討論的場景是關於 GPS 座標間距離的計算,因此選擇 KD 樹的維度是二維。

1. 資料結構

  • 節點: KD 樹中的每個節點代表一個資料點。節點包含一個軸(Axis)和一個分割值(Split Value),以及對應於左子樹和右子樹的指標。

  • 根節點: KD 樹的根節點代表整個資料集的範圍。

  • 葉子節點: KD 樹的葉子節點代表一個單獨的資料點。

2. 構建過程

KD 樹的構建過程基於遞迴的分割策略。通常,我們會選擇一個軸(Axis)和一個分割值(Split Value),將資料集分割成兩個子集。然後,遞迴地對每個子集進行相同的分割操作,直到每個子集中的資料點數量小於某個閾值,或者達到了指定的深度。

構建過程中,通常採用以下策略來選擇軸和分割值:

  • 軸的選擇: 軸的選擇通常是根據資料點在每個維度上的方差或者範圍來確定。可以選擇方差最大的維度作為軸,或者按照輪換的方式選擇每個維度作為軸。

  • 分割值的選擇: 分割值通常是選取當前子集中資料點在選定軸上的中位數。

3. 搜尋過程

KD 樹的搜尋過程也是基於遞迴的。搜尋過程從根節點開始,按照某種規則向下遍歷樹,直到找到目標資料點或者達到葉子節點。

搜尋過程中,通常採用以下策略來確定搜尋順序:

  • 確定分割方向: 根據目標資料點在當前節點所選定的軸上的值,確定搜尋方向。如果目標值小於當前節點的分割值,則向左子樹搜尋;否則向右子樹搜尋。

  • 確定搜尋順序: 根據目標資料點在當前節點所選定的軸上的距離,確定搜尋順序。首先搜尋距離更近的子樹,然後再搜尋距離更遠的子樹。

  • 剪枝策略: 在搜尋過程中,可以採用剪枝策略來減少搜尋的分支,提高搜尋效率。例如,可以計算目標點與當前節點分割超平面的距離,如果距離大於當前最近距離,則可以剪掉該分支。

4. 應用場景

KD 樹常用於高維空間中的資料組織和搜尋,特別是在機器學習和資料探勘領域中。常見的應用包括最近鄰搜尋、範圍搜尋、近似最近鄰搜尋等。

5. 總結

KD 樹是一種高效的多維空間資料結構,適用於快速的最近鄰搜尋和範圍搜尋。它的構建和搜尋過程都基於遞迴的思想,並且可以透過選擇合適的分割策略和剪枝策略來提高搜尋效率。

距離計算

歐式距離

一說到距離計算,最開始想到的就是歐氏距離(歐幾里得距離 Euclidean distance),表示在m維空間中兩個點之間的真實距離。公式為

\[d = \sqrt{\sum_{i=1}^{k}(x_i - y_i)^2} \]

在二維和三維空間中的歐式距離的就是兩點之間的距離,其中二維空間的表示為

\[d = \sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2} \]

其中 \((x_1, x_2), (y_1, y_2)\) 是兩點的座標

我們這裡計算兩地之間的距離,看起來似乎是用二維空間兩點之間的距離就行了,同學也是這麼說的。

曲面上的兩點距離

不過轉念一想,不對啊,這個只是在平面上計算距離,但地球不是個平面,有曲率的啊

於是查了下資料,找到了這個 Haversine(半正矢)公式

Haversine 名字來歷是 Ha-VERSINE,即 Half-Versine ,表示 sin 的一半的意思。Haversine公式給出了用兩點經緯度計算兩點在球面上的距離的方式。

\[haversin(\frac{d}{R}) = haversin(\varphi_2 - \varphi_1) + cos(\varphi_1)cos(\varphi_2)haversin(\Delta\lambda) \]

其中

  • d 是沿著球體大圓的兩點之間的距離(參見球面距離)
  • R 為球體半徑,地球半徑可取平均值 6371km
  • φ1, φ2 表示兩點的緯度
  • Δλ 表示兩點經度的差值

上面應用到圓心角 θ 以及緯度和經度差值的半正矢函式 hav(θ) 的定義為

\[hav(\theta) = \sin^2(\frac{\theta}{2}) = \frac{1-\cos \theta}{2} \]

關於這個公式的更詳細推導過程就省略不表了,感興趣的同學請查閱參考資料。

根據半正矢的定義, archaversine(反半正弦)可以用反正弦表示:

\[archav(h) = 2 \arcsin \sqrt{h} \]

其中 \(0 \le h \le 1\)

代入半正矢公式可得到距離 d 的求解公式:

\[d = 2R \arcsin (\sqrt{\sin^2(\frac{\varphi_2-\varphi_1}{2}) + \cos \theta_2 \cdot \sin^2(\frac{\lambda_2 - \lambda_1}{2})}) \]

PS: 這部分的數學顧問: @Wyu-Cnk

接下來是使用 C# 實現 Haversine 公式計算兩點之間距離(以公里為單位)

public static class DistanceCalculator {
  // Radius of the Earth in kilometers
  private const double EarthRadius = 6371;

  public static double CalculateDistance(ILocation location1, ILocation location2) {
    var lat1 = DegreeToRadian(location1.Lat);
    var lon1 = DegreeToRadian(location1.Lng);
    var lat2 = DegreeToRadian(location2.Lat);
    var lon2 = DegreeToRadian(location2.Lng);

    var dlon = lon2 - lon1;
    var dlat = lat2 - lat1;

    var a = Math.Pow(Math.Sin(dlat / 2), 2) + Math.Cos(lat1) * Math.Cos(lat2) * Math.Pow(Math.Sin(dlon / 2), 2);
    var c = 2 * Math.Atan2(Math.Sqrt(a), Math.Sqrt(1 - a));

    var distance = EarthRadius * c;
    return distance;
  }

  private static double DegreeToRadian(double degree) {
    return degree * Math.PI / 180;
  }
}

初步實現 KD 樹

首先使用 C# 實現基本的 KD 樹,包括 KD 樹的構建,以及最近節點的搜尋功能。

資料結構

為了使程式更加通用,本文使用介面以及泛型對資料結構進行了一定程度的抽象。

地點的資料結構

public interface ILocation {
  public double Lng { get; set; }
  public double Lat { get; set; }
}

public interface ILocationNode {
  public ILocation Location { get; set; }
}

KD 樹節點

public class KdTreeNode<T> where T : ILocationNode {
  public T Value { get; set; }
  public KdTreeNode<T>? Left { get; set; }
  public KdTreeNode<T>? Right { get; set; }
}

KDTree 類

需要用到一個泛型引數,與上述的 KdTreeNode<T> 類中的泛型引數相同。

/// <summary>
/// k-dimensional tree
/// </summary>
/// <typeparam name="T"></typeparam>
public class KdTree<T> where T : ILocationNode {
  private readonly IProgress<string> _progress = new Progress<string>(Console.WriteLine);
  private KdTreeNode<T>? _root;
}

KD 樹構建

使用遞迴來構建 KD 樹。

public void BuildTree(List<T> items) {
  _root = BuildTree(items, 0, items.Count - 1, 0);
}

private KdTreeNode<T>? BuildTree(List<T> items, int start, int end, int depth) {
  if (start > end)
    return null;

  var axis = depth % 2; // Assuming latitude and longitude as 2D coordinates

  items.Sort(
    (a, b) => axis == 0
    ? a.Location.Lng.CompareTo(b.Location.Lng)
    : a.Location.Lat.CompareTo(b.Location.Lat)
  );

  var medianIndex = start + (end - start) / 2;
  var node = new KdTreeNode<T> {
    Value = items[medianIndex],
    Left = BuildTree(items, start, medianIndex - 1, depth + 1),
    Right = BuildTree(items, medianIndex + 1, end, depth + 1)
  };

  // 進度顯示
  _progress.Report($"Splitting on depth {depth}, axis {axis}, median index {medianIndex}");

  return node;
}

使用方式

準備好一批地點資料,轉換為 KdNode<T> 型別,作為 BuildTree 方法的引數執行。

class Store {
  public string Name { get; set; }
  public string Address { get; set; }
  public Location? Location { get; set; }
}

var stores = new List<Store>(){
  // Example coordinates for New York City
  new Store {
    Name = "Store1",
    Location = new Location { lat = 40.7128, lng = -74.0060 }
  },
  // Example coordinates for Los Angeles
  new Store {
    Name = "Store2",
    Location = new Location { lat = 34.0522, lng = -118.2437 }
  },
  // 更多資料請從外部資料來源載入
};

var kdTree = new KdTree<KdNode<Store>>();
kdTree.BuildTreeParallel(stores.Select(e => new KdNode<Store> {
  Node = e,
  Location = e.Location
}).ToList());

最近節點查詢

KD 樹的查詢過程也是基於遞迴,輸入引數為 ILocation 地點座標。

public KdTreeNode<T>? FindNearestNode(ILocation location) {
  return FindNearestNode(_root, location, 0, null);
}

private KdTreeNode<T>? FindNearestNode(KdTreeNode<T>? node, ILocation location, int depth, KdTreeNode<T>? best) {
  if (node == null) return best;

  var bestDistance = best != null
    ? DistanceCalculator.CalculateDistance(location, best.Value.Location)
    : double.PositiveInfinity;
  var currentNodeDistance = DistanceCalculator.CalculateDistance(location, node.Value.Location);

  if (currentNodeDistance < bestDistance)
    best = node;

  var axis = depth % 2;
  var axisDistance =
    axis == 0 ? location.Lng - node.Value.Location.Lng : location.Lat - node.Value.Location.Lat;

  var nearChild = axisDistance < 0 ? node.Left : node.Right;
  var farChild = axisDistance < 0 ? node.Right : node.Left;

  var nearest = FindNearestNode(nearChild, location, depth + 1, best);

  if (nearest != null) {
    var nearestDistance = DistanceCalculator.CalculateDistance(location, nearest.Value.Location);
    if (nearestDistance < bestDistance)
      best = nearest;
  }

  if (Math.Abs(axisDistance) < bestDistance) {
    var farthest = FindNearestNode(farChild, location, depth + 1, best);
    if (farthest != null) {
      var farthestDistance = DistanceCalculator.CalculateDistance(location, farthest.Value.Location);
      if (farthestDistance < bestDistance)
        best = farthest;
    }
  }

  return best;
}

使用方式

// Example coordinates for Tokyo City
var location = new Location { lat = 35.652832, lng = 139.839478 }
var nearestNode = kdTree.FindNearestNode(location);

構建效能

使用 C# 的 StopWatch 工具對 BuildTree 方法進行計時,資料集的大小為 45066 的情況下,上述的程式碼構建 KD 樹耗時為 06:35.9083329 即 6 分鐘 35 秒。

Building KD tree...
KD tree construction complete. 耗時: 00:06:35.9083329

以下是本文程式碼執行環境的 CPU 資訊

12th Gen Intel(R) Core(TM) i7-12700

基準速度:	2.10 GHz
插槽:	1
核心:	12
邏輯處理器:	20
虛擬化:	已啟用
L1 快取:	1.0 MB
L2 快取:	12.0 MB
L3 快取:	25.0 MB

檢查 CPU 各個核心的利用率情況,發現 CPU 只有 2 個核心處於 100% 利用率,整體負載為 20% 左右。

查詢效能

使用上述程式碼中的 FindNearestNode 查詢節點,執行耗時

00:00:00.0242371

效能最佳化

目前的 KD 樹實現已經能夠滿足本文場景的需求,在效能方面,構建過程的耗時較長,查詢過程的速度在可接受範圍內,所以效能最佳化的重點放在了構建過程上。

在構建過程中檢查 CPU 各個核心的利用率情況,發現 CPU 只有 2 個核心處於 100% 利用率,整體負載為 20% 左右,究其原因,如果沒有特別的最佳化,普通的程式碼是以單執行緒的形式執行,現在的 CPU 都是多核架構,單執行緒只能利用到很少的 CPU 效能,因此我的最佳化思路是提高 CPU 的多核利用率。

在 C# 中,有多種方式來讓程式實現並行執行,一般會選擇 Parallel 或者是 Task 類提供的方法,Thread 可以直接對執行緒進行管理操作,是比較底層的實現,一般不直接管理執行緒。

Task 的抽象級別更高,使用 Task (包括 async/await 語句)建立的任務會線上程池中排程執行,,Task 類似於 go 語言中的「協程(Goroutine)」概念(區別在於 Task 透過編譯器+狀態機實現,在編譯期間完成,屬於無棧協程;Goroutine 則是有棧協程)。

Parallel vs Task

Parallel

  • Parallel 類提供了一種簡單的方法來執行並行迴圈或迭代。它允許指定一個迴圈範圍,並自動將其分割成較小的任務,然後並行執行這些任務。
  • Parallel 類的主要目的是簡化並行迴圈的編寫,使得開發者可以輕鬆地利用多核處理器來提高效能,而不必擔心管理執行緒或任務的細節。
  • Parallel 類通常用於處理可迭代的資料集合,比如陣列、列表等。

Task

  • Task 類提供了更加靈活和底層的並行程式設計機制。可以使用 Task 類來建立和管理非同步操作,每個 Task 例項代表一個可執行的操作單元。
  • Task 類允許建立具有不同執行策略和排程選項的任務,例如使用執行緒池執行緒或新執行緒執行任務,也可以設定任務的優先順序、取消任務等。
  • Task 類更適合處理非同步操作,而不僅僅是並行迴圈。可以使用 Task 類來執行任何需要非同步執行的操作,比如非同步檔案 I/O、網路請求等。

小結

  • Task 的擴充套件性比 Parallel 更強
  • Parallel 的效能比 Task 更強
  • Parallel.ForEach 更適合 CPU 密集型任務
  • Task.WhenAll 更適合 IO 密集型任務

針對本文的場景,屬於 CPU 密集型任務,使用 Parallel 來實現的效能會更好一點。為了做對比,本文會分別實現 ParallelTask 兩個版本。

使用 Task 最佳化效能

重構 BuildTree 方法

KdTree<T> 中新增 BuildTreeParallel 方法。

public void BuildTreeParallel(List<T> items) {
  Console.WriteLine("Building KD tree...");
  _root = BuildTreeParallel(items, 0, items.Count - 1, 0);
  Console.WriteLine("KD tree construction complete.");
}

private KdTreeNode<T>? BuildTreeParallel(List<T> items, int start, int end, int depth) {
  if (start > end)
    return null;

  var axis = depth % 2; // Assuming latitude and longitude as 2D coordinates

  // 複製一份新的 items 列表
  var sortedItems = items.ToList();

  sortedItems.Sort(
    (a, b) => axis == 0
    ? a.Location.Lng.CompareTo(b.Location.Lng)
    : a.Location.Lat.CompareTo(b.Location.Lat)
  );

  var medianIndex = start + (end - start) / 2;
  var node = new KdTreeNode<T> {
    Value = sortedItems[medianIndex]
  };

  var leftTask = Task.Run(() => BuildTreeParallel(sortedItems, start, medianIndex - 1, depth + 1));
  var rightTask = Task.Run(() => BuildTreeParallel(sortedItems, medianIndex + 1, end, depth + 1));

  node.Left = leftTask.Result;
  node.Right = rightTask.Result;

  return node;
}

程式碼邏輯與單執行緒版本基本相同,區別在於並行版本將每個節點的葉子節點放在一個新的 Task 中構建。

還有 items 的排序部分,在多執行緒版本中不能直接進行排序,而是要複製一份新的副本進行排序。高併發的場景下,對同一個 items 物件進行排序會引發一致性錯誤問題,具體報錯表現為:

System.AggregateException: One or more errors occurred. (Unable to sort because the IComparer.Compare() method returns inconsistent result
s. Either a value does not compare equal to itself, or one value repeatedly compared to another value yields different results. IComparer:
'System.Comparison`1[StoreProximityBroadbandLocator.Services.KdNode`1[StoreProximityBroadbandLocator.Models.Device]]'.)

大意為:比較器(IComparer)的比較方法返回了不一致的結果。複製一份新的 items 副本進行排序可以解決這個問題。

重構 FindNearestNode 方法

事實上,查詢節點的計算量不大,使用並行最佳化的效果不明顯,反而會因為執行緒排程帶來額外的開銷。

本文提供基於 Task 的實現,但並不提倡使用這種方式。

public KdTreeNode<T>? FindNearestNodeParallel(ILocation location) {
  return FindNearestNodeParallel(_root, location, 0, null);
}

private KdTreeNode<T>? FindNearestNodeParallel(KdTreeNode<T>? node, ILocation location, int depth, KdTreeNode<T>? best) {
  if (node == null) return best;

  var bestDistance = best != null
    ? DistanceCalculator.CalculateDistance(location, best.Value.Location)
    : double.PositiveInfinity;
  var currentNodeDistance = DistanceCalculator.CalculateDistance(location, node.Value.Location);

  if (currentNodeDistance < bestDistance)
    best = node;

  var axis = depth % 2;
  var axisDistance =
    axis == 0 ? location.Lng - node.Value.Location.Lng : location.Lat - node.Value.Location.Lat;

  var nearChild = axisDistance < 0 ? node.Left : node.Right;
  var farChild = axisDistance < 0 ? node.Right : node.Left;

  var nearestTask = Task.Run(() => FindNearestNodeParallel(nearChild, location, depth + 1, best));
  var nearest = nearestTask.Result;

  if (nearest != null) {
    var nearestDistance = DistanceCalculator.CalculateDistance(location, nearest.Value.Location);
    if (nearestDistance < bestDistance)
      best = nearest;
  }

  if (Math.Abs(axisDistance) < bestDistance) {
    var farthestTask = Task.Run(() => FindNearestNodeParallel(farChild, location, depth + 1, best));
    var farthest = farthestTask.Result;
    if (farthest != null) {
      var farthestDistance = DistanceCalculator.CalculateDistance(location, farthest.Value.Location);
      if (farthestDistance < bestDistance)
        best = farthest;
    }
  }

  return best;
}

效能測試

與單執行緒版本一樣,執行 KD 樹構建與查詢節點方法。

現在可以跑滿整個 CPU 了,本文的開發環境使用的 CPU 為 Intel i7-12700,全核頻率最高到 3.6GHz 左右。

基於 Task 實現的 BuildTree 方法執行耗時為 48 秒左右。

使用 Parallel 最佳化效能

重構 BuildTree 方法

KdTree<T> 中新增 BuildTreeParallel2 方法。

public void BuildTreeParallel2(List<T> items) {
  _root = new KdTreeNode<T>();
  BuildTreeParallel2(_root, items, 0, items.Count - 1, 0);
  Console.WriteLine("KD tree construction complete.");
}

private void BuildTreeParallel2(KdTreeNode<T> node, List<T> items, int start, int end, int depth) {
  if (start > end)
    return;

  var axis = depth % 2; // Assuming latitude and longitude as 2D coordinates

  // 複製一份新的 items 列表
  var sortedItems = items.ToList();

  sortedItems.Sort(
    (a, b) => axis == 0
    ? a.Location.Lng.CompareTo(b.Location.Lng)
    : a.Location.Lat.CompareTo(b.Location.Lat)
  );

  var medianIndex = start + (end - start) / 2;
  node.Value = sortedItems[medianIndex];

  if (start < medianIndex) {
    node.Left = new KdTreeNode<T>();
  }

  if (medianIndex < end) {
    node.Right = new KdTreeNode<T>();
  }

  Parallel.Invoke(
    () => {
      if (node.Left != null) {
        BuildTreeParallel2(node.Left, sortedItems, start, medianIndex - 1, depth + 1);
      }
    },
    () => {
      if (node.Right != null) {
        BuildTreeParallel2(node.Right, sortedItems, medianIndex + 1, end, depth + 1);
      }
    }
  );
}

錯誤示例

最開始的時候我寫了個錯誤的程式碼

構建的時候沒問題,但每次查詢的時候總報 null reference 異常

經過 debug 才找出來錯誤的原因

這裡附上最開始的程式碼,然後再來分析一下

node.Left = new KdTreeNode<T>();
node.Right = new KdTreeNode<T>();

Parallel.Invoke(
  () => BuildTreeParallel2(node.Left, sortedItems, start, medianIndex - 1, depth + 1),
  () => BuildTreeParallel2(node.Right, sortedItems, medianIndex + 1, end, depth + 1)
);

舊版的程式碼比較簡潔

看起來似乎很好,但是因為最開始有一行 if (start > end) return;

所以在已經構建完成了某個分支之後,節點的末端還會生成一個空的節點

也不完全是空的,就是一個 KdTreeNode<T> 物件,但裡面的 Value , Location 屬性全是 null

所以只能改成我上面那種方式,既能完全利用CPU,也能解決這個問題。

效能測試

使用 C# 的 StopWatch 工具測試,基於 Parallel 實現的 BuildTree 方法執行耗時為 43 秒左右,對比 Task 的實現確實會快一些。

參考資料

  • https://en.wikipedia.org/wiki/Haversine_formula
  • 一隻兔子幫你理解 kNN
  • kd 樹演算法之思路篇
  • kd 樹演算法之詳細篇 - https://zhuanlan.zhihu.com/p/23966698
  • KD-Tree原理詳解 - https://zhuanlan.zhihu.com/p/112246942
  • A Guide to Parallel Execution in C# : Harness the Power of the Task Parallel Library (TPL)
  • From C# to golang: less is more
  • When would you use Parallel.ForEachAsync() and when Task.WhenAll()
  • Task.WhenAll vs Parallel.Foreach

相關文章