[演算法] 資料結構 splay(伸展樹)解析

Last_Breath發表於2021-05-31

前言

splay學了已經很久了,只不過一直沒有總結,鴿了好久來寫一篇總結。

先介紹 splay:亦稱伸展樹,為二叉搜尋樹的一種,部分操作能在 \(O( \log n)\) 內完成,如插入、查詢、刪除、查詢序列第 \(k\) 大、查詢字首(比查詢的數小的數中最大的數)、查詢字尾(比查詢的數大的數中最小的數)等操作,甚至能夠實現區間平移。它由 Daniel Sleator 和 Robert Endre Tarjan 在1985年發明的。注:時間複雜度是均攤為 \(O(\log n)\) ,是經過嚴謹的證明的,單個操作可能退化成 \(O(n)\)

本文例題連結

演算法思想

先做一個小小的引入:輸入法中,你經常使用詞語,會在詞條中靠前的位置。實現過程可以使用 splay。

splay 是二叉搜尋樹的一種,這裡簡單介紹一下二叉搜尋樹。

對於一棵二叉樹,滿足樹上任意節點,它的左子樹上任意節點滿足比當前節點的權值小,右子樹上任意節點的權值比當前節點的權值大。則稱這棵樹為二叉搜尋樹。

可以利用二叉搜尋樹的性質來進行操作,比當前節點的權值小就在左子樹查詢,權值大就在右子樹查詢。

理想狀態下,若該二叉樹為一顆完全二叉樹,則單次操作時間複雜度為 \(O(\log n)\) 。但這顆二叉樹可能退化成一條鏈,這樣單次時間複雜度為 \(O(n)\)

splay 樹在這上面進行了改進,通過不斷改變樹的形態來保證不會退化,均攤時間複雜度為 \(O(\log n)\) 。基本思想是把搜尋頻率高的點放在深度小的位置,為了操作方便,可以認為每次操作的點都是頻率高的。常常把操作的點,或是操作區間的兩個端點放在根或根的附近的位置,那麼會涉及到旋轉操作。

根據勢能函式分析(我不會),splay 的時間複雜度上限為 \(O((m+n)\log n)\) ,但這個上限是有波動的。

基本操作

建議配合註釋一起使用。

結構體中應包含以下資訊:

struct Splay_Node {
	int son[2], val, cnt, siz, fa;
//分別是:兩個兒子,權值,副本數,子樹大小,父親節點
	#define ls t[pos].son[0] //巨集定義左兒子,方便一些
	#define rs t[pos].son[1] //右兒子,同上
};

簡單說明一下,副本數為權值為 val 的數的個數。

New

開闢新節點,裡面的值隨需求變化,以下是幾個重要的值。

int New(int val, int fa) {
	t[++tot].fa = fa, t[tot].cnt = t[tot].siz = 1, t[tot].val = val;
	return tot;
}

Build

建立splay樹,將極小值置為根節點,極大值作為根節點的右兒子,滿足二叉搜尋樹的性質,程式碼:

void Build() {
	root = New(-INF, 0); //極小值為根節點 
	t[root].son[1] = New(INF, root); //極大值為右兒子 
}

寫這段程式碼的主要原因是:使得 splay 的每個節點不會爆掉邊界,否則很容易就 RE 。

Ident

判斷該節點為父節點的左兒子還是右兒子,左兒子為 \(0\) ,右兒子為 \(1\)

bool Ident(int pos) { return t[t[pos].fa].son[1] == pos; } 

Update

更新子樹大小,還更新節點資訊(由需求所定)。

void Update(int pos) {
	t[pos].siz = t[ls].siz + t[rs].siz + t[pos].cnt; //子樹大小為左右子樹大小加上自己的副本數
}

Connect

將一對點變為父子關係。

void Connect(int pos, int fa, int flag) {//依次是:子節點,父節點,哪個兒子
	t[fa].son[flag] = pos;//將fa的兒子置為pos
	t[pos].fa = fa;//將pos的父親置為fa
}

Rotate

既然要把一個點旋轉到根節點,那麼就必須先掌握單旋操作,具體分兩個情況討論。

左兒子旋轉至父節點

在這裡插入圖片描述

如上圖,需要進行幾次轉換: \(x\) 的左兒子變為 \(y\) 的右兒子, \(y\) 的右兒子變為\(x\)\(a\) 的子節點變為 \(y\)

那麼程式可以寫為:

void Rotate(int pos) {//這裡的flag1=0,可以按照上述的三個轉換進行驗證這段程式是對的
	int fa = t[pos].fa, grand = t[fa].fa;
	int flag1 = Ident(pos), flag2 = Ident(fa);
	Connect(pos, grand, flag2);
	Connect(t[pos].son[flag1 ^ 1], fa, flag1);
	Connect(fa, pos, flag1 ^ 1);
	Update(fa); Update(pos);
}

右兒子旋轉至父節點

可以視為上圖的逆操作: \(y\) 的右兒子變為 \(x\) 的左兒子, \(x\) 的左兒子變為\(y\)\(a\) 的子節點變為 \(x\)

那麼程式依舊可以寫為:

void Rotate(int pos) {//這裡的flag1=1,可以按照上述的三個轉換進行驗證這段程式是對的
	int fa = t[pos].fa, grand = t[fa].fa;
	int flag1 = Ident(pos), flag2 = Ident(fa);
	Connect(pos, grand, flag2);
	Connect(t[pos].son[flag1 ^ 1], fa, flag1);
	Connect(fa, pos, flag1 ^ 1);
	Update(fa); Update(pos);
}

綜上所述,Rotate 操作可以不用判斷左右節點,寫法為上述程式。

Splay

聽名字就知道,這是splay樹的核心操作。

函式 \(splay(pos,to)\) 定義為:將編號為 \(x\) 的節點,旋轉至父親為 \(to\) 的節點(即 \(to\) 的其中一個子節點,且進行 splay 後依然滿足二叉搜尋樹的性質)。

顯然有一種方法:對於當前節點 \(pos\) ,不停進行 \(Rotate(pos)\) ,知道 \(pos\) 的父節點為 \(to\) 為止。

但是這並不能使該 splay 樹的形態發生太大的改變。splay 的目的是改變樹的形態,有一種改進的方法:雙旋。順帶說明一下,單旋會被卡成 \(O(nm)\) 。(具體我也不知道怎麼卡)

雙旋即一次旋轉兩次,設當前點為 \(x\) ,父親節點為 \(y\) ,爺爺為 \(z\) 。具體分為兩種情況,這裡只證明正確性。

x、y、z 形成一條鏈

在這裡插入圖片描述

這種情況先單旋 \(y\) 在單旋 \(x\) 。過程見下圖:
在這裡插入圖片描述

顯然,在上述過程中,嚴謹地滿足了 \(val[x]>val[y]>val[z]\)

x、y、z 形成“<”或 “>”

直接進行兩次單旋操作,正確性顯然。

Code

程式碼很短,只有三行。

void Splay(int pos, int to) {
	for(int fa = t[pos].fa; t[pos].fa != to; Rotate(pos), fa = t[pos].fa)
		if(t[fa].fa != to) Ident(pos) == Ident(fa) ? Rotate(fa) : Rotate(pos);
//Ident(pos) == Ident(fa)意味著pos和fa成為了一條鏈的形狀,否則為“<”或“>”。
	if(!to) root = pos;//更新根節點,根節點的父親值為0
}

總結

這些是 splay 的基本操作,之後的所有操作都是建立在這些之上的。

引申操作

Find

定義 \(Find(val)\) :查詢權值為 \(val\) 的點的編號,若沒有該點就返回 \(0\)

利用 splay 為二叉搜尋樹的性質,若 \(val\) 小於當前節點的權值,則在左子樹中查詢;若大於則在右子樹中查詢。知道找到當前節點的編號為 \(0\) 或當前節點的權值等於 \(val\) 的時候返回改點的下標。

int Find(int pos, int val) {
	if(!pos) return 0;//空節點直接返回
	if(val == t[pos].val) return pos;//等於就直接返回節點編號
	else if(val < t[pos].val) return Find(ls, val);//在左子樹中查詢
	else return Find(rs, val);//在右子樹中查詢
}

Insert

即插入操作, 需要插入權值為 \(val\) 的值。

其思想跟 \(Find\) 函式差不多,利用二叉搜尋樹的性質直接就可以找到插入的位置。具體分為兩類:

  1. 有權值為 \(val\) 的點 \(pos\) ,直接使得副本數加 \(1\) 即可。
  2. 沒有權值為 \(val\) 的點 \(pos\) ,則開闢一個新的節點權值為 \(val\)

注意 \(pos\) 應傳實參,因為若開闢了新的節點,其父節點的對應兒子也需要改變。

void Insert(int &pos, int val, int fa) {//pos為實參
	if(!pos) Splay(pos = New(val, fa), 0);
	else if(val == t[pos].val) { ++t[pos].cnt; Splay(pos, 0); }
	else if(val < t[pos].val) Insert(ls, val, pos);
	else Insert(rs, val, pos);
}

Erase

即刪除操作, \(Erase(val)\) 定義為:刪除所維護的序列中權值為 \(val\) 的一個節點(如果有的話)。

可以先找到權值為 \(val\) 的節點並定義其編號為 \(pos\) ,分兩種情況。

  1. 若當前節點的副本數大於 \(1\) 時,即 \(t[pos].cnt>1\) 時,可以刪除其中一個副本即可,但並沒有刪除這個節點。
  2. 否則,則需要刪除該節點。需要先將 \(pos\) splay 到根節點。找到它的字首的編號 \(l\) 和它的字尾的編號 \(r\) ,則 \(t[l].val\leq val \leq t[r].val\) 。顯然, \((t[l].val,t[r].val)\) 區間內的數只有一個,即 \(pos\) 。將 \(l\) splay 至根節點, \(r\) splay 至 \(l\) 的右兒子,則 \(pos\) 必會在 \(r\) 的左兒子處,因為 \(l\)\(r\)\(pos\) 必回滿足二叉搜尋樹的性質。然後直接刪除 \(r\) 的左兒子即可。
void Erase(int val) {
	int pos = Find(root, val);//找到權值為 val 的點。
	if(!pos) return;//沒有改節點直接返回,沒有難倒刪空氣?
	if(t[pos].cnt > 1) { --t[pos].cnt; Splay(pos, 0); return; }//對應情況1
	Splay(pos, 0);
	int l = ls, r = rs;
	while(t[l].son[1]) l = t[l].son[1];//找到字首
	while(t[r].son[0]) r = t[r].son[0];//找到字尾
	Splay(l, 0); Splay(r, l);//對應情況2
	t[r].son[0] = 0;
}

這裡在提供一種做法,與 \(Find\) 函式的做法類似,可以說是其的升級版。總體框架不變,主要是針對第二種情況,將其旋轉到根節點在進行刪除,這種寫法還是比較常見的。

void Erase(int pos, int val) {
	if(!pos) return;
	if(val == t[pos].val) {
		if(t[pos].cnt > 1) { t[pos].cnt--; Splay(pos, 0); return; }
		if(ls) Rotate(ls), Erase(pos, val);//有左兒子跟左兒子交換
		else if(rs) Rotate(rs), Erase(pos, val);//有右兒子就跟右兒子交換
		else {//沒有兒子就直接刪除,注意必須刪除其父親的對應兒子
			int newroot = t[pos].fa;
			t[t[pos].fa].son[Ident(pos)] = 0;
			Splay(newroot, 0);
		}
		return;
	}
	else if(val < t[pos].val) rase(ls, val);
	else Erase(rs, val);
}

Query_kth

查詢 \(val\) 在序列是第幾大的樹,即按照從小到大的順序排序後, \(val\) 的排名,沒有 \(val\) 輸出返回 \(-1\)

程式碼使用遞迴實現,考慮對於當前節點 \(pos\) ,比 \(val\) 小的數都在左子樹內,即有 \(t[ls].siz\) 個樹比 \(t[pos].val\) 小。

對於區域性解,可以將 \(Querykth(pos,val)\) 函式理解為 \(pos\) 的子樹中,小於 \(val\) 的值有多少。

則可以分為三種情況來討論。

  1. \(val=t[pos].val\) 時,即找到了該節點,返回比它小的數的個數即可,即左子樹的節點數加 \(1\)
  2. \(val<t[pos].val\) 時, \(val\) 左子樹中,在左子樹中查詢該節點的排名。
  3. \(val>t[pos].val\) 時, 是最麻煩的部分。 \(val\) 右子樹中,左子樹與當前節點都會為答案做貢獻,先將其統計至答案中,在求出右子樹對於答案的貢獻。

注意,最後的答案是包含了極小值的,所以找到後的答案應該減一,這一部分我寫在了主函式裡,所以沒找到會輸出 \(-1\)

int Query_kth(int pos, int val) {
	if(!pos) return 0;//沒有輸出-1
	if(val == t[pos].val) { int res = t[ls].siz + 1; Splay(pos, 0); return res; }//對應情況1
	else if(val < t[pos].val) return Query_kth(ls, val);//對於情況2
	//下兩行程式碼對應情況3
	int res = t[ls].siz + t[pos].cnt;//找到後splay維護形態會導致子樹的大小變化,因此先記錄答案
	return Query_kth(rs, val) + res;
}

Query_val

查詢區間的第 \(k\) 小的數。

可以看做上一個操作的逆操作吧,若 \(k\) 都大於了區間的所有數的個數,就直接返回極大值。

同樣,對於區域性解,可以將 \(Queryval(pos,k)\) 函式理解為 \(pos\) 的子樹中,第 \(k\) 大值為多少。

又可以分為三個情況:

  1. \(t[ls].siz\geq k\) 時,即所求答案在左子樹,在左邊查詢即可。
  2. \(t[ls].siz+t[pos].cnt\geq k\) 時, 答案為 \(t[pos].val\) ,因為第 \(t[ls].siz+1\) 小至 \(t[ls].siz+t[pos].cnt\) 的數全部權值都為 \(t[pos].val\)
  3. 否則,答案全部會在右子樹當中,查詢右子樹第 \(k-t[ls].siz-t[pos].cnt\) 大,因為當前節點與左兒子一定比右子樹任何一個數小。

同樣的需要注意,最後的答案是包含了極小值的,同樣這一部分我寫在了主函式裡,查詢的時候需要查詢第 \(k+1\) 大的那個數。

int Query_val(int pos, int rank) {
	if(!pos) return INF;
	if(t[ls].siz >= rank) return Query_val(ls, rank);
	else if(t[ls].siz + t[pos].cnt >= rank) { Splay(pos, 0); return t[pos].val; }
	return Query_val(rs, rank - t[ls].siz - t[pos].cnt);
}

Get_Pre、Get_Nxt

\(Erase\) 操作中提到過,可以使用那樣的做法。

亦可使用在文末的程式碼中稍快的做法,與 \(Find\) 函式相似,這裡就不多說了。(其實是不想打字了

也可以參照這段程式碼將一些操作寫為非遞迴的寫法,會更快一些。

總結

有些細心的同學可能已經發現了,幾乎每個操作都有 splay 操作來維護當前樹的形態,保證時間複雜度。

C++程式碼

只是將上述操作拼起來放在一個程式碼裡。

說明一下操作的幾種型別:

  1. 插入 \(x\) 數。
  2. 刪除 \(x\) 數(若有多個相同的數,因只刪除一個)。
  3. 查詢 \(x\) 數的排名(排名定義為比當前數小的數的個數 \(+1\) )。
  4. 查詢排名為 \(x\) 的數。
  5. \(x\) 的前驅(前驅定義為小於 \(x\),且最大的數)。
  6. \(x\) 的後繼(後繼定義為大於 \(x\),且最小的數)。

不是特別長,實現的方法也並不困難,打的時候必須得注意,完整沒附上註釋的程式碼:

#include <cstdio>
namespace Quick_Function {
	template <typename Temp> void Read(Temp &x) {
		x = 0; char ch = getchar(); bool op = 0;
		while(ch < '0' || ch > '9') { if(ch == '-') op = 1; ch = getchar(); }
		while(ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
		if(op) x = -x;
	}
	template <typename T, typename... Args> void Read(T &t, Args &... args) { Read(t); Read(args...); }
	template <typename Temp> Temp Max(Temp x, Temp y) { return x > y ? x : y; }
	template <typename Temp> Temp Min(Temp x, Temp y) { return x < y ? x : y; }
	template <typename Temp> Temp Abs(Temp x) { return x < 0 ? (-x) : x; }
	template <typename Temp> void Swap(Temp &x, Temp &y) { x ^= y ^= x ^= y; }
}
using namespace Quick_Function;
#define INF 0x3f3f3f3f
const int MAXN = 1e6 + 5;
int n;
struct Splay_Node {
	int son[2], val, cnt, siz, fa;
	#define ls t[pos].son[0]
	#define rs t[pos].son[1]
};
struct Splay_Tree {
	int root, tot;
	Splay_Node t[MAXN];
	bool Ident(int pos) { return t[t[pos].fa].son[1] == pos; } 
	int New(int val, int fa) {
		t[++tot].fa = fa, t[tot].cnt = t[tot].siz = 1, t[tot].val = val;
		return tot;
	}
	void Build() { root = New(-INF, 0); t[root].son[1] = New(INF, root); }
	void Update(int pos) { t[pos].siz = t[ls].siz + t[rs].siz + t[pos].cnt; }
	void Connect(int pos, int fa, int flag) { t[fa].son[flag] = pos, t[pos].fa = fa; }
	void Rotate(int pos) {
		int fa = t[pos].fa, grand = t[fa].fa;
		int flag1 = Ident(pos), flag2 = Ident(fa);
		Connect(pos, grand, flag2);
		Connect(t[pos].son[flag1 ^ 1], fa, flag1);
		Connect(fa, pos, flag1 ^ 1);
		Update(fa); Update(pos);
	}
	void Splay(int pos, int to) {
		for(int fa = t[pos].fa; t[pos].fa != to; Rotate(pos), fa = t[pos].fa)
			if(t[fa].fa != to) Ident(pos) == Ident(fa) ? Rotate(fa) : Rotate(pos);
		if(!to) root = pos;
	}
	int Find(int pos, int val) {
		if(!pos) return 0;
		if(val == t[pos].val) return pos;
		else if(val < t[pos].val) return Find(ls, val);
		else return Find(rs, val);
	}
	void Insert(int &pos, int val, int fa) {
		if(!pos) Splay(pos = New(val, fa), 0);
		else if(val == t[pos].val) { ++t[pos].cnt; Splay(pos, 0); }
		else if(val < t[pos].val) Insert(ls, val, pos);
		else Insert(rs, val, pos);
	}
	void Erase(int val) {
		int pos = Find(root, val);
		if(!pos) return;
		if(t[pos].cnt > 1) { --t[pos].cnt; Splay(pos, 0); return; }
		Splay(pos, 0);
		int l = ls, r = rs;
		while(t[l].son[1]) l = t[l].son[1];
		while(t[r].son[0]) r = t[r].son[0];
		Splay(l, 0); Splay(r, l);
		t[r].son[0] = 0;
	}
	int Query_kth(int pos, int val) {
		if(!pos) return 0;
		if(val == t[pos].val) { int res = t[ls].siz + 1; Splay(pos, 0); return res; }
		else if(val < t[pos].val) return Query_kth(ls, val);
		int res = t[ls].siz + t[pos].cnt;
		return Query_kth(rs, val) + res;
	}
	int Query_val(int pos, int rank) {
		if(!pos) return INF;
		if(t[ls].siz >= rank) return Query_val(ls, rank);
		else if(t[ls].siz + t[pos].cnt >= rank) { Splay(pos, 0); return t[pos].val; }
		return Query_val(rs, rank - t[ls].siz - t[pos].cnt);
	}
	int Get_Pre(int val) {
		int pos, res, newroot;
		pos = newroot = root;
		while(pos) {
			if(t[pos].val < val) { res = t[pos].val; pos = rs; }
			else pos = ls;
		}
		Splay(newroot, 0);
		return res;
	}
	int Get_Nxt(int val) {
		int pos, res, newroot;
		pos = newroot = root;
		while(pos) {
			if(t[pos].val > val) { res = t[pos].val; pos = ls; }
			else pos = rs;
		}
		Splay(newroot, 0);
		return res;
	}
};
Splay_Tree tree;
int main() {
	tree.Build(); Read(n); 
	for(int i = 1, opt, x; i <= n; i++) {
		Read(opt, x);
		if(opt == 1) tree.Insert(tree.root, x, 0);
		else if(opt == 2) tree.Erase(x);
		else if(opt == 3) printf("%d\n", tree.Query_kth(tree.root, x) - 1);
		else if(opt == 4) printf("%d\n", tree.Query_val(tree.root, x + 1));
		else if(opt == 5) printf("%d\n", tree.Get_Pre(x));
		else printf("%d\n", tree.Get_Nxt(x));
	}
	return 0;
}

相關文章