計算幾何——平面最近點對

RainPPR發表於2024-05-23

計算幾何——平面最近點對

問題描述

給定平面上 \(n\)\(n\ge2\))個點,找出一堆點,使得其間距離最短。

下午將介紹分治做法、非分治做法,以及期望線性做法。

其中執行速度(P1429 平面最近點對 加強版)大致為,非分治做法最快,期望線性做法最慢。

樸素演算法

非常顯然了,\(\mathcal O(n^2)\) 的遍歷每一對點。

非常簡略的一個程式碼示例,

function solve():
	min_dist = inf
	for i in point_set:
		for j in point_set:
			min_dist = min(min_dist, dist(i,j))
	return min_dist

簡單剪枝

考慮一種常見的統計序列的思想:

依次加入每一個元素,統計它和其左邊所有元素的貢獻。

具體地,

  • 我們把所有點按照 \(x_i\) 為第一關鍵字、\(y_i\) 為第二關鍵字排序。

  • 同時,建立一個以 \(y_i\) 為第一關鍵字、 \(x_i\) 為第二關鍵字排序的 multiset

  • 對於每一個位置 \(i\),我們執行以下操作:

  1. 假設我們已經算出來的最小距離是 \(d\)

  2. 將所有滿足 \(|x_i-x_j|\ge d\) 的點從集合中刪除。它們不會再對答案有貢獻。

  3. 對於集合內滿足 \(|y_i-y_j|< d\) 的所有點,統計它們和 \((x_i,y_i)\) 的距離。

  4. \((x_i,y_i)\) 插入到集合中。

這個演算法的複雜度為 \(\mathcal O(n\log n)\),比分治做法常數略小,證明略。

程式碼:

#include <bits/stdc++.h>

using namespace std;

using ll = long long;

struct emm {
    ll x, y;
    emm() = default;
    emm(ll x, ll y): x(x), y(y) {}
    friend ll operator *(const emm &a, const emm &b) {
        return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
    }
};

struct cmp_x {
    bool operator ()(const emm &a, const emm &b) const {
        return a.x == b.x ? a.y < b.y : a.x < b.x;
    }
};

struct cmp_y {
    bool operator ()(const emm &a, const emm &b) const {
        return a.y < b.y;
    }
};

double ans = 1e10;

void upd_ans(const emm &a, const emm &b) {
    double dist = sqrt(a * b);
    if (dist < ans) ans = dist;
}

vector<emm> a;

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    int n; cin >> n; a.resize(n);
    for (int i = 0; i < n; ++i) cin >> a[i].x >> a[i].y;
    sort(a.begin(), a.end(), cmp_x());
    multiset<emm, cmp_y> s;
    for (int i = 0, l = 0; i < n; ++i) {
        while (l < i && a[i].x - a[l].x >= ans) s.erase(s.find(a[l++]));
        auto it = s.lower_bound(emm(a[i].x, a[i].y - ans));
        for (; it != s.end() && it->y - a[i].y < ans; ++it) upd_ans(a[i], *it);
        s.insert(a[i]);
    }
    printf("%.4lf\n", ans);
    return 0;
}

這個做法的問題在於 multiset 的大常數,但是好寫。

分治演算法

考慮如果我們把所有點按照 \(x_i\) 排序,分治解決。

  1. 假設我們已經算出來的最小距離是 \(d\)

  2. 考慮如何合併,顯然只有兩個集合分界線處各 \(d\) 距離內的點需要考慮。

  3. 我們列舉這個小集合內的點,計算每個點向下最多 \(d\) 個單位的點的貢獻。

因為當前最小距離 \(d\)、向下列舉的是 \(d\times2d\) 的矩陣,其內部的點的個數是 \(\mathcal O(1)\) 的。

因此,整體複雜度即考慮分治的複雜度,即 \(\mathcal O(n\log n)\),但是常數比非分治略大。

程式碼:

#include <bits/stdc++.h>

using namespace std;

using ll = long long;
using db = double;

struct emm {
    ll x, y;
    emm() = default;
    emm(ll x, ll y): x(x), y(y) {}
    friend ll operator *(const emm &a, const emm &b) {
        return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
    }
};

struct cmp_x {
    bool operator ()(const emm &a, const emm &b) const {
        return a.x == b.x ? a.y < b.y : a.x < b.x;
    }
};

struct cmp_y {
    bool operator ()(const emm &a, const emm &b) const {
        return a.y < b.y;
    }
};

double ans = 1e10;

void upd_ans(const emm &a, const emm &b) {
    double dist = sqrt(a * b);
    if (dist < ans) ans = dist;
}

vector<emm> a;

void merge(int l, int r) {
    if (l == r) return;
    int m = l + r >> 1; ll mx = a[m].x;
    merge(l, m), merge(m + 1, r);
    inplace_merge(a.begin() + l, a.begin() + m + 1, a.begin() + r + 1, cmp_y());
    vector<emm> t;
    for (int i = l; i <= r; ++i) {
        if (abs(a[i].x - mx) >= ans) continue;
        for (auto j = t.rbegin(); j != t.rend(); ++j) {
            if (a[i].y - j->y >= ans) break;
            upd_ans(a[i], *j);
        }
        t.push_back(a[i]);
    }
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    int n; cin >> n; a.resize(n);
    for (int i = 0; i < n; ++i) cin >> a[i].x >> a[i].y;
    sort(a.begin(), a.end(), cmp_x());
    merge(0, n - 1);
    printf("%.4lf\n", ans);
    return 0;
}

使用了 std::inplace_merge 作為歸併,詳見 cppreference。

期望線性

注意是期望線性做法,複雜度理論期望值是 \(\mathcal O(n)\) 的。

但是實際上常數巨大,而且容易被卡,實測速度反而最慢。

  1. 同樣我們考慮加入一個點的貢獻,但是這裡需要先隨機打亂。

  2. 記前 \(i-1\) 個點的最近點對距離為 \(d\),將平面以 \(d\) 為邊長劃分為若干個網格。

  3. 檢查第 \(i\) 個點所在網格的周圍九個網格中的所有點,並更新答案。

  4. 使用雜湊表存下每個網格內的點,如果答案被更新,就重構網格圖,否則不重構。

因為前 \(i-1\) 個點的最近點對距離為 \(d\),從而每個網格不超過 \(4\) 個點。

注意到需檢查的點的個數是 \(\mathcal O(1)\) 的,在前 \(i\) 個點中,最近點對包含 \(i\) 的機率為
\(\mathcal O(1/i)\)

而重構網格的代價為 \(\mathcal O(i)\),從而第 \(i\) 個點的期望代價為 \(\mathcal O(1)\)

於是對於 \(n\) 個點,該演算法期望為 \(\mathcal O(n)\)

程式碼:

#include <bits/stdc++.h>

using namespace std;

struct my_hash {
  static uint64_t splitmix64(uint64_t x) {
    x += 0x9e3779b97f4a7c15;
    x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
    x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
    return x ^ (x >> 31);
  }

  size_t operator()(uint64_t x) const {
    static const uint64_t FIXED_RANDOM =
        chrono::steady_clock::now().time_since_epoch().count();
    return splitmix64(x + FIXED_RANDOM);
  }

  size_t operator()(pair<uint64_t, uint64_t> x) const {
    static const uint64_t FIXED_RANDOM =
        chrono::steady_clock::now().time_since_epoch().count();
    return splitmix64(x.first + FIXED_RANDOM) ^
           (splitmix64(x.second + FIXED_RANDOM) >> 1);
  }
};

mt19937 rng(time(0));

using ll = long long;
using grid = pair<int, int>;

struct emm {
    ll x, y;
    emm() = default;
    emm(ll x, ll y): x(x), y(y) {}
    friend ll operator *(const emm &a, const emm &b) {
        return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
    }
};

double ans = 1e10;

void upd_ans(const emm &a, const emm &b) {
    double dist = sqrt(a * b);
    if (dist < ans) ans = dist;
}

vector<emm> a;

unordered_map<grid, vector<emm>, my_hash> ump;

#define group(e, t) make_pair(e.x / (int)t, e.y / (int)t)

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);
    int n; cin >> n; a.resize(n);
    for (int i = 0; i < n; ++i) cin >> a[i].x >> a[i].y;
    shuffle(a.begin(), a.end(), rng);
    for (int i = 0; i < n; ++i) {
        double lt = ans;
        int tx, ty; tie(tx, ty) = group(a[i], lt);
        for (int kx = tx - 1; kx <= tx + 1; ++kx) {
            for (int ky = ty - 1; ky <= ty + 1; ++ky) {
                auto eq = make_pair(kx, ky);
                if (!ump.count(eq)) continue;
                for (emm j : ump[eq]) upd_ans(a[i], j);
            }
        }
        if (ans == 0) break;
        if (ans != lt) {
            ump = decltype(ump)();
            for (int j = 0; j < i; ++j) ump[group(a[j], ans)].push_back(a[j]);
        }
        ump[group(a[i], ans)].push_back(a[i]);
    }
    printf("%.4lf\n", ans);
    return 0;
}

這個演算法的常數很大,主要在於雜湊(程式碼裡手寫了雜湊函式)。

相關文章