【資料結構】K-D Tree

HinanawiTenshi發表於2022-01-26

K-D Tree

這東西是我入坑 ICPC 不久就聽說過的資料結構,但是一直沒去學 QAQ,終於在昨天去學了它。還是挺好理解的,而且也有用武之地。

目錄

簡介

建樹過程

性質

操作

例題

簡介

K-D Tree(KDT , k-Dimension Tree) 是一種可以 高效處理 \(k\) 維空間資訊 的資料結構。更具體地說,它是維護了 \(k\) 維空間 \(n\) 個點的資料結構,而且它是一棵平衡樹

建樹過程

由於二維的形式是競賽中最常見而且便於講解,故以二維的情況為例。所以下面構建的都是 2-D Tree。

先使用一個具體的例子來模擬建樹的過程:

image

現在如圖給出一個點集,接下來開始用這個點集來構建 2-D Tree。

首先,我們以 \(x\) 軸座標(即第一維座標)為關鍵字,選取中位數所對應的點為根節點,這裡也就是 \(A\) 點。

image

類似地,以 \(y\) 軸座標(即第二維座標)為關鍵字,選取中位數所對應的點為根節點,這裡也就是 \(C\)\(E\) 點。

image

與上面類似,再以 \(x\) 軸座標(即第一維座標)為關鍵字繼續構建:

image

至此,所給出來的點集已經通過一棵二叉樹維護了起來。

總結一下建樹的過程:

  • 以當前的關鍵字選取中位數所對應的點作為當前子樹的根節點。

  • 交替地選擇不同維度為關鍵字(如果是 \(k\) 維的座標系那麼假設先前選了第 \(x\) 維,下一維就是 \(x\%k+1\)

  • 將根節點向左右兒子連邊。

性質

  1. 樹的每一層都是按同一個關鍵字劃分。(從上面的建樹過程可以明顯地看出)
  2. 樹的一棵子樹可以劃分出一個矩形(二維)。比如下面的 \(D,E,F\)​​,只要我們將每個點的座標維護起來,那麼矩形的左下端點就是所有點 \(x,y\)​ 座標的最小值,右上端點就是所有點 \(x,y\) 座標的最大值。
    image

操作

接下來說一下 K-D Tree 的關鍵操作。

\(\texttt{insert}\)

也就是插入操作,和普通的二叉搜尋樹類似,從根節點開始比較,決定向左子樹還是右子樹移動,最後走到需要插入的位置。

\(\texttt{rebuild}\)

重構操作。

顯然,上面的插入操作並不能夠保證我們的樹高為 \(logN\),所以我們需要進行專門的操作來維持樹高,換句話說,我們要保證這棵二叉搜尋樹為平衡樹

怎麼保證平衡呢?由上面的性質 \(1\) 可知我們不能對樹進行旋轉,那麼可以利用替罪羊樹的思想:重構。引入重構常數 \(\alpha\),在執行插入操作後,如果發現當前的子樹的根節點的左子樹或右子樹的大小佔整棵子樹的大小超過 \(\alpha\),那麼我們就進行重構

上面說到 K-D Tree 類似於二叉搜尋樹,因此可以通過它的中序遍歷得到一個序列,我們利用這個序列進行重構就可以了。

\(\texttt{query}\) 操作因情況而異,故於例題介紹。

例題

例 1:

傳送門:

https://www.luogu.com.cn/problem/P4148

分析

操作 \(1\) 就是 \(\texttt{insert}\),只不過帶了點權。

操作 \(2\)​​​ 是對矩形進行詢問,由上面所說的性質 \(2\),K-D Tree 的子樹正好可以劃分出一個矩形,因此我們可以採取類似於線段樹區間查詢的做法:

  • 從根節點出發開始查詢。
  • 如果當前子樹所對應的矩形和查詢的矩形沒有交集,返回 \(0\)
  • 如果被當前子樹所對應的矩形被查詢的矩形包含,直接返回當前子樹的權值和 \(sum\)
  • 否則向左右子樹遞迴繼續詢問。

程式碼

#include<bits/stdc++.h>
using namespace std;

#define debug(x) cerr << #x << ": " << (x) << endl
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define dwn(i,a,b) for(int i=(a);i>=(b);i--)

inline void read(int &x){
    int s=0; x=1;
    char ch=getchar();
    while(ch<'0' || ch>'9') {if(ch=='-')x=-1;ch=getchar();}
    while(ch>='0' && ch<='9') s=(s<<3)+(s<<1)+ch-'0',ch=getchar();
    x*=s;
}

const int N=5e5+5;

struct Point{
	int x[2], w;
};

struct Node{
	int l, r;
	Point P;
	int L[2], R[2], sum, sz;
	
	#define ls tr[u].l
	#define rs tr[u].r
}tr[N];

int n;

int idx, root;
int buf[N], tot;
int add(){
	if(!tot) return ++idx;
	return buf[tot--];
}

void pushup(int u){
	auto &L=tr[ls], &R=tr[rs];
	tr[u].sum=tr[u].P.w+L.sum+R.sum, tr[u].sz=L.sz+R.sz+1;

	rep(i,0,1){
		tr[u].L[i]=min(tr[u].P.x[i], min(L.L[i], R.L[i]));
		tr[u].R[i]=max(tr[u].P.x[i], max(L.R[i], R.R[i]));
	}
}

const double Al=0.72;

Point pt[N];

void getSeq(int u, int cnt){
	if(ls) getSeq(ls, cnt);
	buf[++tot]=u, pt[tr[ls].sz+1+cnt]=tr[u].P;
	if(rs) getSeq(rs, cnt+tr[ls].sz+1);
}

int rebuild(int l, int r, int k){
	if(l>r) return 0;
	int mid=l+r>>1;
	int u=add();
	
	nth_element(pt+l, pt+mid, pt+r+1, [&](Point a, Point b){
		return a.x[k]<b.x[k];
	});
	tr[u].P=pt[mid];
	
	ls=rebuild(l, mid-1, k^1), rs=rebuild(mid+1, r, k^1);
	pushup(u);
	return u;
}

void maintain(int &u, int k){
	if(tr[u].sz*Al<tr[ls].sz || tr[u].sz*Al<tr[rs].sz)
		getSeq(u, 0), u=rebuild(1, tot, k);	
}

void insert(int &u, Point p, int k){
	if(!u){
		u=add();
		tr[u].l=tr[u].r=0;
		tr[u].P=p, pushup(u);
		return;
	}
	if(p.x[k]<=tr[u].P.x[k]) insert(ls, p, k^1);
	else insert(rs, p, k^1);
	pushup(u);
	maintain(u, k);
}

bool In(Node t, int x1, int y1, int x2, int y2){
	return t.L[0]>=x1 && t.R[0]<=x2 && t.L[1]>=y1 && t.R[1]<=y2;	
}

bool In(Point p, int x1, int y1, int x2, int y2){
	return p.x[0]>=x1 && p.x[0]<=x2 && p.x[1]>=y1 && p.x[1]<=y2;
}

bool Out(Node t, int x1, int y1, int x2, int y2){
	return t.R[0]<x1 || t.L[0]>x2 || t.R[1]<y1 || t.L[1]>y2;	
}

int query(int u, int x1, int y1, int x2, int y2){
	if(In(tr[u], x1, y1, x2, y2)) return tr[u].sum;
	if(Out(tr[u], x1, y1, x2, y2)) return 0;
	
	int res=0;
	if(In(tr[u].P, x1, y1, x2, y2)) res+=tr[u].P.w;
	res+=query(ls, x1, y1, x2, y2)+query(rs, x1, y1, x2, y2);
	return res;
}

int main(){
	cin>>n;
	// init
	tr[0].L[0]=tr[0].L[1]=N+5;
	tr[0].R[0]=tr[0].R[1]=-1;
	
	int res=0, op;
	while(cin>>op, op!=3){
		if(op==1){
			int x, y, k; read(x), read(y), read(k);
			insert(root, {x^res, y^res, k^res}, 0);
		}	
		else{
			int x1, y1, x2, y2; read(x1), read(y1), read(x2), read(y2);
			cout<<(res=query(root, x1^res, y1^res, x2^res, y2^res))<<endl;
		}
	}

	return 0;
}

例 2:

傳送門:

https://www.acwing.com/problem/content/256/

https://www.luogu.com.cn/problem/P4169

(注意兩個 OJ 的資料有差異)

分析

操作 \(1\) 就是 \(\texttt{insert}\)

操作 \(2\) 我們採取搜尋剪枝的方法:

  • 從根節點出發詢問,記現在查詢到 \(u\) 點了。
  • \(u\)​ 點和查詢點的距離更新答案。
  • 查詢點\(u\)​​ 點的左右子樹(分別對應兩個矩形)的期望距離(我將其稱為估價函式),也就是這個點到矩形的最小曼哈頓距離。並決定是否向該子樹遞迴。

因為用到這樣的搜尋了,複雜度是 \(O(能過)\)​​​,但是實際執行效果非常不錯,我的程式碼在洛谷可以到最優解的第二頁,而且最優解的前幾名大多是用 K-D Tree 寫的;同時,在 acwing 的執行時是 \(5628 ms\)​​​,而選取幾份使用 \(cdq\)​​​ 分治的卻都在 \(10000 ms\)​​​ 左右。

// Problem: 天使玩偶
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/description/256/
// Memory Limit: 256 MB
// Time Limit: 7000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#include<bits/stdc++.h>
using namespace std;

namespace IO{
	int f; char c;

	template<typename T> inline void read(T &v){
		v = 0; f = 1; c = getchar();
		while(!isdigit(c)) { if(c == '-') f = -1; c = getchar(); }
		while(isdigit(c)) { v = (v << 3) + (v << 1) + (int)(c - '0'); c = getchar(); }
		v *= f;
		return;
	}

	template<typename T> inline void write(T k){
		if(k < 0) { putchar('-'); k = -k; }
		if(k > 9) write(k / 10);
		putchar((char)(k % 10 + '0'));
		return;
	}

	inline int Read() { int v; read(v); return v; }
	inline void Write(int v, char ed = '\n') { write(v); putchar(ed); return; }
}

#define debug(x) cerr << #x << ": " << (x) << endl
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define dwn(i,a,b) for(int i=(a);i>=(b);i--)

using IO::read;
using IO::Write;

const int N=1e6+5, INF=0x3f3f3f3f;

struct Point{
	int x[2];
};

struct Node{
	int l, r;
	Point P;
	int L[2], R[2], sz;
	
	#define ls tr[u].l
	#define rs tr[u].r
}tr[N];

int n, m;

int idx, root;
int buf[N], tot;
int add(){
	if(!tot) return ++idx;
	return buf[tot--];
}

void pushup(int u){
	auto &L=tr[ls], &R=tr[rs];
	tr[u].sz=L.sz+R.sz+1;

	rep(i,0,1){
		tr[u].L[i]=min(tr[u].P.x[i], min(L.L[i], R.L[i]));
		tr[u].R[i]=max(tr[u].P.x[i], max(L.R[i], R.R[i]));
	}
}

const double Al=0.75;

Point pt[N];

void getSeq(int u, int cnt){
	if(ls) getSeq(ls, cnt);
	buf[++tot]=u, pt[tr[ls].sz+1+cnt]=tr[u].P;
	if(rs) getSeq(rs, cnt+tr[ls].sz+1);
}

int rebuild(int l, int r, int k){
	if(l>r) return 0;
	int mid=l+r>>1;
	int u=add();
	
	nth_element(pt+l, pt+mid, pt+r+1, [&](Point a, Point b){
		return a.x[k]<b.x[k];
	});
	tr[u].P=pt[mid];
	
	ls=rebuild(l, mid-1, k^1), rs=rebuild(mid+1, r, k^1);
	pushup(u);
	return u;
}

void maintain(int &u, int k){
	if(tr[u].sz*Al<tr[ls].sz || tr[u].sz*Al<tr[rs].sz)
		getSeq(u, 0), u=rebuild(1, tot, k);	
}

void insert(int &u, Point p, int k){
	if(!u){
		u=add();
		tr[u].l=tr[u].r=0;
		tr[u].P=p, pushup(u);
		return;
	}
	if(p.x[k]<=tr[u].P.x[k]) insert(ls, p, k^1);
	else insert(rs, p, k^1);
	pushup(u);
	maintain(u, k);
}

int dis(Point a, Point b){
	return abs(a.x[0]-b.x[0])+abs(a.x[1]-b.x[1]);
}

int H(Node t, Point p){
	int x=p.x[0], y=p.x[1];
	return max(0, t.L[0]-x)+max(0, t.L[1]-y)+max(0, x-t.R[0])+max(0, y-t.R[1]);
}

int res;

void query(int u, Point p){
	if(!u) return;
	res=min(res, dis(tr[u].P, p));
	int LV=INF, RV=INF;
	if(ls) LV=H(tr[ls], p);
	if(rs) RV=H(tr[rs], p);
	
	if(LV<RV){
		if(LV<res) query(ls, p);
		if(RV<res) query(rs, p);
	}
	else{
		if(RV<res) query(rs, p);
		if(LV<res) query(ls, p);
	}
}

int main(){
	cin>>n>>m;
	// init
	tr[0].L[0]=tr[0].L[1]=N+5;
	tr[0].R[0]=tr[0].R[1]=-1;
	
	rep(i,1,n){
		int x, y; read(x), read(y);
		pt[i]={x, y};
	}
	root=rebuild(1, n, 0);
	
	rep(i,1,m){
		int op, x, y; read(op), read(x), read(y);
		if(op==1) insert(root, {x, y}, 0);
		else{
			res=INF;
			query(root, {x, y});
			Write(res);
		}
	}

	return 0;
}

相關文章