[筆記]AVL樹

Sinktank發表於2024-06-16

AVL樹是一種嚴格平衡的二叉搜尋樹,任何操作結束後,都能保證每個節點的左右子樹高度相差不超過\(1\)
內容源自BV1rt411j7Ff - 【AgOHの資料結構】平衡樹專題之叄 樹旋轉與AVL樹

模板題:P3369 【模板】普通平衡樹

結構體定義 & 基本函式

struct node{
	int l;//左孩子
    int r;//右孩子
    int v;//值
    int hei;//高度,葉節點為1
    int siz;//大小
}avl[N];
int cnt;//當前用到哪一個節點了,用於新建節點
int root;//根節點
//新建節點
void newnode(int &u,int v){
	avl[u=++cnt].v=v;//賦值
	avl[cnt].siz=1;//葉子節點
}
//更新節點資訊
void update(int u){
	avl[u].siz=avl[avl[u].l].siz+avl[avl[u].r].siz+1;//左+右+1
	avl[u].hei=max(avl[avl[u].l].hei,avl[avl[u].r].hei)+1;//max(左,右)+1
}

左右旋轉

AVL樹用旋轉來維護樹的平衡。旋轉分左旋和右旋:
[筆記]AVL樹

//左旋
void lrot(int &u){
	int r=avl[u].r;
	avl[u].r=avl[r].l;
	avl[r].l=u;
	u=r;
	update(avl[u].l),update(u);
}
//右旋
void rrot(int &u){
	int l=avl[u].l;
	avl[u].l=avl[l].r;
	avl[l].r=u;
	u=l;
	update(avl[u].r),update(u);
}

接下來我們需要判斷並處理AVL樹的不平衡情況。

//計算平衡因子(即左子樹高度-右子樹高度)
int factor(int u){
	return avl[avl[u].l].hei-avl[avl[u].r].hei;
}

對於樹上的節點\(u\)(假定\(u\)的子樹都平衡),其不平衡狀態有\(4\)種:

  • LL:\(u\)的左子樹過高,而左子節點的左子樹較高。
    處理方法:右旋一次\(u\)
  • LR:\(u\)的左子樹過高,而左子結點的右子樹較高。
    處理方法:設\(v\)\(u\)的左兒子,先左旋\(v\)(轉化成LL),再右旋\(u\)
  • RR:\(u\)的右子樹過高,而右子節點的右子樹較高。
    處理方法:左旋一次\(u\)
  • RL:\(u\)的右子樹過高,而右子節點的左子樹較高。
    處理方法:設\(v\)\(u\)的右兒子,先右旋\(v\)(轉化成RR),再左旋\(u\)
[筆記]AVL樹

若左子節點的左右子樹高度相同,則既可以歸納為LL,也可以作為LR考慮。右子節點同理。

//檢查並調整為平衡狀態,並更新節點的資訊
void check(int &u){
	int uf=factor(u);
	if(uf>1){
		int lf=factor(avl[u].l);
		if(lf>0) rrot(u);//LL
		else lrot(avl[u].l),rrot(u);//LR
	}else if(uf<-1){
		int rf=factor(avl[u].r);
		if(rf<0) lrot(u);//RR
		else rrot(avl[u].r),lrot(u);//RL
	}else if(u) update(u);//如果原本就平衡,且u不為空,就要更新
}

其他操作

和普通的BST一樣了。

//插入
void ins(int &u,int v){
	if(!u) newnode(u,v);
	else if(v<avl[u].v) ins(avl[u].l,v);
	else ins(avl[u].r,v);
	check(u);//自下向上更新節點資訊&調整結構
}
//找u的後繼(即u先往右走,再不斷往左直到沒有左子結點)v,
//讓v的父節點直接連線v的右子樹
int find(int &u,int fa){
	int ans;
	if(!avl[u].l){//終點
		ans=u;
		avl[fa].l=avl[u].r;
	}else{
		ans=find(avl[u].l,u);
		check(u);
	}
	return ans;
}
//刪除
void del(int &u,int v){
	if(v==avl[u].v){
		int l=avl[u].l,r=avl[u].r;
		if(!l||!r) u=l+r;
		else{
			u=find(r,r);//u的後繼v來替代u的位置
			avl[u].l=l;//v成為子樹的根,連線左邊
			if(u!=r) avl[u].r=r;//連線右邊
		}
	}else if(v<avl[u].v) del(avl[u].l,v);
	else del(avl[u].r,v);
	check(u);//自下向上更新節點資訊&調整結構
}
//計算v的排名(小於v的個數+1)
int getrank(int v){
	int u=root,ran=1;
	while(u){
		if(v<=avl[u].v) u=avl[u].l;
		else{
			ran+=avl[avl[u].l].siz+1;
			u=avl[u].r;
		}
	}
	return ran;
}
//計算第ran名
int getnum(int ran){
	int u=root;
	while(u){
		if(avl[avl[u].l].siz+1==ran) break;
		else if(avl[avl[u].l].siz>=ran)
			u=avl[u].l;
		else
			ran-=avl[avl[u].l].siz+1,u=avl[u].r;
	}
	return avl[u].v;
}
//前驅
int pre(int x){return getnum(getrank(x)-1);}
//後繼
int nex(int x){return getnum(getrank(x+1));}

Code

點選檢視程式碼
#include<bits/stdc++.h>
#define int long long
#define N 100010
using namespace std;
struct node{
	int l,r,v,hei,siz;
}avl[N];
int t,cnt,root;
void newnode(int &u,int v){
	avl[u=++cnt].v=v;
	avl[cnt].siz=1;
}
void update(int u){
	avl[u].siz=avl[avl[u].l].siz+avl[avl[u].r].siz+1;
	avl[u].hei=max(avl[avl[u].l].hei,avl[avl[u].r].hei)+1;
}
int factor(int u){
	return avl[avl[u].l].hei-avl[avl[u].r].hei;
}
void lrot(int &u){
	int r=avl[u].r;
	avl[u].r=avl[r].l;
	avl[r].l=u;
	u=r;
	update(avl[u].l),update(u);
}
void rrot(int &u){
	int l=avl[u].l;
	avl[u].l=avl[l].r;
	avl[l].r=u;
	u=l;
	update(avl[u].r),update(u);
}
void check(int &u){
	int uf=factor(u);
	if(uf>1){
		int lf=factor(avl[u].l);
		if(lf>=0) rrot(u);//LL
		else lrot(avl[u].l),rrot(u);//LR
	}else if(uf<-1){
		int rf=factor(avl[u].r);
		if(rf<=0) lrot(u);//RR
		else rrot(avl[u].r),lrot(u);//RL
	}else if(u) update(u);
}
void ins(int &u,int v){
	if(!u) newnode(u,v);
	else if(v<avl[u].v) ins(avl[u].l,v);
	else ins(avl[u].r,v);
	check(u);
}
int find(int &u,int fa){
	int ans;
	if(!avl[u].l){//終點
		ans=u;
		avl[fa].l=avl[u].r;
	}else{
		ans=find(avl[u].l,u);
		check(u);
	}
	return ans;
}
void del(int &u,int v){
	if(v==avl[u].v){
		int l=avl[u].l,r=avl[u].r;
		if(!l||!r) u=l+r;
		else{
			u=find(r,r);//找u的後繼,即比u大的第一個數
			avl[u].l=l;
			if(u!=r) avl[u].r=r;
		}
	}else if(v<avl[u].v) del(avl[u].l,v);
	else del(avl[u].r,v);
	check(u);
}
int getrank(int v){
	int u=root,ran=1;
	while(u){
		if(v<=avl[u].v) u=avl[u].l;
		else{
			ran+=avl[avl[u].l].siz+1;
			u=avl[u].r;
		}
	}
	return ran;
}
int getnum(int ran){
	int u=root;
	while(u){
		if(avl[avl[u].l].siz+1==ran) break;
		else if(avl[avl[u].l].siz>=ran)
			u=avl[u].l;
		else
			ran-=avl[avl[u].l].siz+1,u=avl[u].r;
	}
	return avl[u].v;
}
int pre(int x){return getnum(getrank(x)-1);}
int nex(int x){return getnum(getrank(x+1));}
signed main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cin>>t;
	while(t--){
		int op,x;
		cin>>op>>x;
		if(op==1) ins(root,x);
		else if(op==2) del(root,x);
		else if(op==3) cout<<getrank(x)<<"\n";
		else if(op==4) cout<<getnum(x)<<"\n";
		else if(op==5) cout<<pre(x)<<"\n";
		else if(op==6) cout<<nex(x)<<"\n";
	}
	return 0;
}

附:相同節點合併寫法

當時看完影片,想到是不是能把相同節點計數,存在一個節點中。

於是就寫出下面的程式碼了。結構體多存了一個\(cnt\),然後newnodeupdateinsdelgetrankgetnum函式需要做相應的修改。

點選檢視程式碼
#include<bits/stdc++.h>
#define int long long
#define N 100010
using namespace std;
struct node{
	int l,r,v,hei,siz,cnt;
}avl[N];
int t,cnt,root;
void newnode(int &u,int v){
	avl[u=++cnt].v=v;
	avl[cnt].siz=1;
	avl[cnt].cnt=1;
}
void update(int u){
	avl[u].siz=avl[avl[u].l].siz+avl[avl[u].r].siz+avl[u].cnt;
	avl[u].hei=max(avl[avl[u].l].hei,avl[avl[u].r].hei)+1;
}
int factor(int u){
	return avl[avl[u].l].hei-avl[avl[u].r].hei;
}
void lrot(int &u){
	int r=avl[u].r;
	avl[u].r=avl[r].l;
	avl[r].l=u;
	u=r;
	update(avl[u].l),update(u);
}
void rrot(int &u){
	int l=avl[u].l;
	avl[u].l=avl[l].r;
	avl[l].r=u;
	u=l;
	update(avl[u].r),update(u);
}
void check(int &u){
	int uf=factor(u);
	if(uf>1){
		int lf=factor(avl[u].l);
		if(lf>=0) rrot(u);//LL
		else lrot(avl[u].l),rrot(u);//LR
	}else if(uf<-1){
		int rf=factor(avl[u].r);
		if(rf<=0) lrot(u);//RR
		else rrot(avl[u].r),lrot(u);//RL
	}else if(u) update(u);
}
void ins(int &u,int v){
	if(!u) newnode(u,v);
	else if(v==avl[u].v) avl[u].cnt++;
	else if(v<avl[u].v) ins(avl[u].l,v);
	else ins(avl[u].r,v);
	check(u);
}
int find(int &u,int fa){
	int ans;
	if(!avl[u].l){//終點
		ans=u;
		avl[fa].l=avl[u].r;
	}else{
		ans=find(avl[u].l,u);
		check(u);
	}
	return ans;
}
void del(int &u,int v){
	if(v==avl[u].v){
		if(avl[u].cnt>1) avl[u].cnt--;
		else{
			int l=avl[u].l,r=avl[u].r;
			if(!l||!r) u=l+r;
			else{
				u=find(r,r);//找u的後繼,即比u大的第一個數
				avl[u].l=l;
				if(u!=r) avl[u].r=r;
			}
		}
	}else if(v<avl[u].v) del(avl[u].l,v);
	else del(avl[u].r,v);
	check(u);
}
int getrank(int v){//小於自己的個數+1
	int u=root,ran=1;
	while(u){
		if(v<=avl[u].v) u=avl[u].l;
		else{
			ran+=avl[avl[u].l].siz+avl[u].cnt;
			u=avl[u].r;
		}
	}
	return ran;
}
int getnum(int ran){
	int u=root;
	while(u){
		if(avl[avl[u].l].siz+1<=ran&&ran<=avl[avl[u].l].siz+avl[u].cnt) break;
		//如果ran在[siz[l]+1,siz[l]+cnt[u]]的區間內,就說明第ran名就是u
		else if(avl[avl[u].l].siz>=ran)
			u=avl[u].l;
		else
			ran-=avl[avl[u].l].siz+avl[u].cnt,u=avl[u].r;
	}
	return avl[u].v;
}
int pre(int x){return getnum(getrank(x)-1);}
int nex(int x){return getnum(getrank(x+1));}
signed main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cin>>t;
	while(t--){
		int op,x;
		cin>>op>>x;
		if(op==1) ins(root,x);
		else if(op==2) del(root,x);
		else if(op==3) cout<<getrank(x)<<"\n";
		else if(op==4) cout<<getnum(x)<<"\n";
		else if(op==5) cout<<pre(x)<<"\n";
		else if(op==6) cout<<nex(x)<<"\n";
	}
	return 0;
}

兩種寫法效率相當,不合並183ms,合併190ms。
似乎相同節點合併反而更慢?

相關文章