KDTree求平面最近點對

HarlemBlog發表於2024-11-12
更新日誌

思路

對於每一個點都求一邊其最短距離,最後統計。

詳細地,對於每個區間,先與其中點判斷並更新距離(注意特判不能是同一點),然後找出在這一節點排序維度上,查詢點到中點距離,記作 \(D\)

看查詢點在中點左側/右側,判斷去左右區間。(在這一節點排序的維度上。)

同側更新完之後,如果一側的答案要大於當前點到中點的一維距離,就說明答案的另一點可能在另一側,再查詢另一側的子區間。

模板

struct node{
    int id;
    ll v[K];
    int lson,rson;
}ns[N];
int n;

int ak;
bool cmp(node a,node b){return a.v[ak]<b.v[ak];}

ll dis(int a,int b){
    ll res=0;
    for(int k=0;k<K;k++){
        res+=(ns[a].v[k]-ns[b].v[k])*(ns[a].v[k]-ns[b].v[k]);
    }
    return res;
}
ll dis(int a,int b,int k){
    return (ns[a].v[k]-ns[b].v[k])*(ns[a].v[k]-ns[b].v[k]);
}

int mp[N];

struct kdtree{
    int build(int l,int r,int k=0){
        int m=l+r>>1;
        ak=k;nth_element(ns+l,ns+m,ns+r+1,cmp);
        mp[ns[m].id]=m;
        if(l<m)ns[m].lson=build(l,m-1,(k+1)%K);
        if(m<r)ns[m].rson=build(m+1,r,(k+1)%K);
        return m;
    }
    ll query(int q,int x,int k=0){
        ll res=INF,tes=INF;
        if(x!=q)res=min(res,dis(x,q));
        tes=dis(x,q,k);
        if(ns[q].v[k]<=ns[x].v[k]){
            if(ns[x].lson)res=min(res,query(q,ns[x].lson,(k+1)%K));
            if(tes<res&&ns[x].rson)res=min(res,query(q,ns[x].rson,(k+1)%K));
        }else{
            if(ns[x].rson)res=min(res,query(q,ns[x].rson,(k+1)%K));
            if(tes<res&&ns[x].lson)res=min(res,query(q,ns[x].lson,(k+1)%K));
        }
        return res;
    }
}kdt;

例題

HDU2966

程式碼

前注:就是模板。

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef __int128 i128;
typedef double db;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
template <typename Type>
using vec=vector<Type>;
template <typename Type>
using grheap=priority_queue<Type>;
template <typename Type>
using lrheap=priority_queue<Type,vector<Type>,greater<Type> >;
#define fir first
#define sec second
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define dprint(x) cout<<#x<<"="<<x<<"\n";

const int inf=0x3f3f3f3f;
const ll INF=4e18;
const int mod=1e9+7/*998244353*/;

const int N=1e5+5,K=2;

struct node{
    int id;
    ll v[K];
    int lson,rson;
}ns[N];
int n;

int ak;
bool cmp(node a,node b){return a.v[ak]<b.v[ak];}

ll dis(int a,int b){
    ll res=0;
    for(int k=0;k<K;k++){
        res+=(ns[a].v[k]-ns[b].v[k])*(ns[a].v[k]-ns[b].v[k]);
    }
    return res;
}
ll dis(int a,int b,int k){
    return (ns[a].v[k]-ns[b].v[k])*(ns[a].v[k]-ns[b].v[k]);
}

int mp[N];

struct kdtree{
    int build(int l,int r,int k=0){
        int m=l+r>>1;
        ak=k;nth_element(ns+l,ns+m,ns+r+1,cmp);
        mp[ns[m].id]=m;
        if(l<m)ns[m].lson=build(l,m-1,(k+1)%K);
        if(m<r)ns[m].rson=build(m+1,r,(k+1)%K);
        return m;
    }
    ll query(int q,int x,int k=0){
        ll res=INF,tes=INF;
        if(x!=q)res=min(res,dis(x,q));
        tes=dis(x,q,k);
        if(ns[q].v[k]<=ns[x].v[k]){
            if(ns[x].lson)res=min(res,query(q,ns[x].lson,(k+1)%K));
            if(tes<res&&ns[x].rson)res=min(res,query(q,ns[x].rson,(k+1)%K));
        }else{
            if(ns[x].rson)res=min(res,query(q,ns[x].rson,(k+1)%K));
            if(tes<res&&ns[x].lson)res=min(res,query(q,ns[x].lson,(k+1)%K));
        }
        return res;
    }
}kdt;

void solve(){
    cin>>n;
    for(int i=1;i<=n;i++){
        ns[i].id=i;
        cin>>ns[i].v[0]>>ns[i].v[1];
        ns[i].lson=ns[i].rson=0;
    }
    int rt=kdt.build(1,n);
    for(int i=1;i<=n;i++){
        cout<<kdt.query(mp[i],rt)<<"\n";
    }
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    int t;cin>>t;
    while(t--)solve();
    return 0;
}

相關文章