最近學的

lying_up發表於2024-07-10

fhq-treap

定義

給一棵二叉搜尋樹的每一個節點隨機賦了個值 \(pri\) , 滿足父親節點的 \(pri\) 小於其子節點的 \(pri\) , 相當於一個小根堆。

操作

分裂

即把 treap 一分為二。分為兩種 , 按值分裂和按子樹大小分裂。

按值分裂

把 treap 分為兩棵 , 一棵全部 \(\le\) key , 另一棵全部 \(>\) key , 因為 treap 滿足二叉搜尋樹的性質 , 節點左兒子一定小於本身,右兒子大於本身。所以遇到一個節點小於 key 時 , 直接把該節點和其左子樹分裂出去 , 再繼續遞迴分裂右子樹。

inline void split(int p,int key,int &l,int &r){
	if(!p) return l=r=0,void();
	if(tr[p].val<=key) l=p,split(tr[p].r,key,tr[p].r,r);
	else r=p,split(tr[p].l,key,l,tr[p].l);
}

按子樹大小分裂

分裂為兩棵 treap , 使一棵子樹大小為 \(sz\) , 和按值分裂是同理的。

inline void split(int p,int sz,int &l,int &r){
	if(!p) return l=r=0,void();
	if(tr[tr[p].l].sz<sz){
	    sz-=tr[tr[p].l].sz+1; l=p;
            split(tr[p].r,sz,tr[p].r,r);
	}
	else{
	   r=p; 
	   split(tr[p].l,sz,l,tr[p].l);
	}
}

合併

把兩棵 treap 合併為一棵 , 注意要滿足兩棵 treap 的值域不交 , 即一棵的值全部 \(\le\) 另一棵 , 而且還要注意 \(pri\) 的關係,實現比較簡單。

inline int merge(int l,int r){
	if(!l || !r) return l+r;
	if(tr[l].pri<tr[r].pri){
	    tr[l].r=merge(tr[l].r,r);
	    return l;
	}
	else{
	    tr[r].l=merge(l,tr[r].l);
	    return r;
	}
}

時間複雜度

樹的高度相當於隨機序列的笛卡爾樹的高度 , 為 \(O(\log n)\) , 分裂和合並複雜度顯然為樹高 , 為 \(O(\log n)\)

例題

P3369

考慮用 treap 維護序列。

  • 操作一直接按 \(x\) 分裂為 \(L,R\) 兩顆 treap , 然後再合併 \(L,x,R\)

  • 操作二按 \(x\) 分裂為 \(L,R\) 兩顆 treap ,將 \(L\)\(x-1\) 分為 \(L,M\), 合併 \(M\) 的左右子樹 , 再和 \(L,R\) 合併 , 就相當於吞了個 \(x\)

  • 操作三按 \(x\) 分裂後查子樹大小

  • 操作四類似於線段樹上二分 , 從根開始每次看當前排名是否 \(>\) 左子樹大小 , 是就減去跳右子樹 , 不是就跳到左子樹

  • 操作五按 \(x-1\) 分裂為 \(L,R\) 後 , 由於二叉搜尋樹的性質 , 一直跳 \(L\) 的右子樹 , 直到跳不了

  • 操作六同操作五

點選檢視程式碼
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i=(a);i<=(n);++i)
#define per(i,a,n) for(int i=(n);i>=(a);--i)
#define SZ(x) ((int)(x).size())

using namespace std;
typedef double db;
typedef long long ll;
typedef pair<int,int> pii;

const int maxn=5e5+10;

mt19937 rnd(time(0));

struct node{
	int l,r,val,sz,pri;
}tr[maxn];

int cur,rt;

inline int newnode(int x){
	tr[++cur]={0,0,x,1,(int)rnd()};
	return cur;
}

inline void upd(int p){
	tr[p].sz=1;
	if(tr[p].l) tr[p].sz+=tr[tr[p].l].sz;
	if(tr[p].r) tr[p].sz+=tr[tr[p].r].sz;
}

inline void split(int p,int key,int &l,int &r){
	if(!p) return l=r=0,void();
	if(tr[p].val<=key) l=p,split(tr[p].r,key,tr[p].r,r);
	else r=p,split(tr[p].l,key,l,tr[p].l);
	upd(p);
}

inline int merge(int l,int r){
	if(!l || !r) return l+r;
	if(tr[l].pri<tr[r].pri){
		tr[l].r=merge(tr[l].r,r); upd(l);
		return l;
	}
	else{
		tr[r].l=merge(l,tr[r].l); upd(r);
		return r;
	}
}

inline int kth(int k){
	int u=rt;
	while(true){
		if(tr[tr[u].l].sz+1==k) return tr[u].val;
		if(tr[tr[u].l].sz+1<k) k-=tr[tr[u].l].sz+1,u=tr[u].r; 
		else u=tr[u].l; 
	}
}

inline int pre(int x){
	int l,r,u;
	split(rt,x-1,l,r);
	u=l;
	while(true){
		if(!tr[u].r){
			rt=merge(l,r);
			return tr[u].val;
		} 
		u=tr[u].r;
	}
}

inline int suf(int x){
	int l,r,u;
	split(rt,x,l,r);
	u=r;
	while(true){
		if(!tr[u].l){
			rt=merge(l,r);
			return tr[u].val;
		} 
		u=tr[u].l;
	}
}

int T;

inline void solve(){
	scanf("%d",&T);
	while(T--){
		int opt,x,l,r,mid;
		scanf("%d%d",&opt,&x);
		if(opt==1){
			int id=newnode(x);
			split(rt,x,l,r);
			rt=merge(merge(l,id),r);
		}
		if(opt==2){
			split(rt,x,l,r);
			split(l,x-1,l,mid);
			rt=merge(merge(merge(l,tr[mid].l),tr[mid].r),r);
		}
		if(opt==3){
			split(rt,x-1,l,r);
			printf("%d\n",tr[l].sz+1);
			rt=merge(l,r);
		}
		if(opt==4) printf("%d\n",kth(x));
		if(opt==5) printf("%d\n",pre(x));
		if(opt==6) printf("%d\n",suf(x));
	}			
	
	
}

signed main(){
    int _=1;
    //scanf("%d",&_); 
    while(_--) solve();
}

P5338

給上個題差不多 , 不過換成了 pair 而已

點選檢視程式碼
#pragma G++ optimize(2)
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i=(a);i<=(n);++i)
#define per(i,a,n) for(int i=(n);i>=(a);--i)
#define SZ(x) ((int)(x).size())

using namespace std;
typedef double db;
typedef long long ll;
typedef pair<int,int> pii;

const int maxn=2e6+10;

mt19937 rnd(time(0));

struct node{
	int l,r,sz,pri;
	pii val;
}tr[maxn];

int cur,rt;

inline int newnode(pii x){
	tr[++cur]={0,0,1,(int)rnd(),x};
	return cur;
}

inline void upd(int p){
	tr[p].sz=1;
	if(tr[p].l) tr[p].sz+=tr[tr[p].l].sz;
	if(tr[p].r) tr[p].sz+=tr[tr[p].r].sz;
}

inline void split(int p,pii key,int &l,int &r){
	if(!p) return l=r=0,void();
	if(tr[p].val<=key) l=p,split(tr[p].r,key,tr[p].r,r);
	else r=p,split(tr[p].l,key,l,tr[p].l);
	upd(p);
}

inline int merge(int l,int r){
	if(!l || !r) return l+r;
	if(tr[l].pri<tr[r].pri){
		tr[l].r=merge(tr[l].r,r); upd(l);
		return l;
	}
	else{
		tr[r].l=merge(l,tr[r].l); upd(r);
		return r;
	}
}

inline void dbg(int p){
	if(!p) return;
	if(tr[p].l) dbg(tr[p].l);
	printf("%d %d\n",tr[p].val.first,tr[p].val.second);
	if(tr[p].r) dbg(tr[p].r);

}
typedef unsigned int ui;
ui randNum(ui& seed, ui last, const ui m){ 
    seed = seed * 17 + last;
    return seed % m + 1; 
}

template<typename T> 
inline void pr(T x,bool op=false){
    x<0?x=-x,putchar('-'):0;
    static short sta[25],top(0);
    do sta[top++]=x%10,x/=10; while(x);
    while(top) putchar(sta[--top]|48);
    op?putchar('\n'):putchar(' ');
} 

int n,m;
pii a[maxn];
ui seed,lst;

inline void solve(){
	cur=rt=0;
	scanf("%d%d%u",&n,&m,&seed);	
	rep(i,1,n){
		a[i]={0,0};
		int pos=newnode(a[i]),l,r;
		split(rt,a[i],l,r);
		rt=merge(merge(l,pos),r);
	} 
	while(m--){
		int l,r,mid,id=randNum(seed,lst,n),cnt=randNum(seed,lst,n);
		// printf("----->%d %d\n",id,cnt);
		pii tmp=a[id]; tmp.second--;
		split(rt,a[id],l,r); //<=a[id] >a[id]
		split(l,tmp,l,mid);
		a[id].first--; a[id].second+=cnt;		
		rt=merge(merge(merge(l,tr[mid].l),tr[mid].r),r);
		tr[mid].val=a[id]; tr[mid].l=tr[mid].r=0;
		tr[mid].sz=1;
		tmp=a[id]; tmp.second--;
		split(rt,tmp,l,r);
		pr(lst=tr[l].sz,1);
		rt=merge(merge(l,mid),r);
		// puts("ok");
		// dbg(rt);
	}

}


signed main(){
	lst=7;
    int _=1;
    scanf("%d",&_); 
    while(_--) solve();
}

P3391

用 treap 維護序列 , 使得樹的中序遍歷為序列 , 其實就是相當於把下標作為值。注意到區間反轉在樹上為交換左右子樹 , 按子樹大小分裂然後打懶標記即可。

點選檢視程式碼
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i=(a);i<=(n);++i)
#define per(i,a,n) for(int i=(n);i>=(a);--i)
#define SZ(x) ((int)(x).size())

using namespace std;
typedef double db;
typedef long long ll;
typedef pair<int,int> pii;

const int maxn=1e5+10;

mt19937 rnd(time(0));

struct node{
	int l,r,sz,pri,val;
	bool rev;
}tr[maxn];

int cur,rt;

inline int newnode(int x){
	tr[++cur]={0,0,1,(int)rnd(),x,false};
	return cur;
}

inline void settag(int p){
	tr[p].rev^=1;
	swap(tr[p].l,tr[p].r);
}

inline void down(int p){
	if(tr[p].rev){
		settag(tr[p].l); settag(tr[p].r);
		tr[p].rev=false;
	}
}

inline void up(int p){
	tr[p].sz=1;
	if(tr[p].l) tr[p].sz+=tr[tr[p].l].sz;
	if(tr[p].r) tr[p].sz+=tr[tr[p].r].sz;
}

inline void split(int p,int sz,int &l,int &r){
	if(!p) return l=r=0,void();
	down(p);
	if(tr[tr[p].l].sz<sz){
		sz-=tr[tr[p].l].sz+1; l=p;
		split(tr[p].r,sz,tr[p].r,r);
	}
	else{
		r=p; 
		split(tr[p].l,sz,l,tr[p].l);
	}
	up(p);
}

inline int merge(int l,int r){
	if(!l || !r) return l+r;
	if(tr[l].pri<tr[r].pri){
		down(l);
		tr[l].r=merge(tr[l].r,r);
		up(l);
		return l;
	}
	else{
		down(r);
		tr[r].l=merge(l,tr[r].l);
		up(r);
		return r;
	}
}

inline void dbg(int p){
	if(!p) return;
	down(p);
	if(tr[p].l) dbg(tr[p].l);
	printf("%d ",tr[p].val);
	if(tr[p].r) dbg(tr[p].r);

}

int n,m;

inline void solve(){
	scanf("%d%d",&n,&m);
	rep(i,1,n){
		int id=newnode(i);
		rt=merge(rt,i);
	}	
	while(m--){
		int a,b,l,r,mid;
		scanf("%d%d",&a,&b);
		split(rt,b,l,r);
		split(l,a-1,l,mid);
		settag(mid);
		rt=merge(merge(l,mid),r);
	}
	dbg(rt);
}

signed main(){
    int _=1;
    //scanf("%d",&_); 
    while(_--) solve();
}

P4146

比上一題多了些維護的東西而已,寫法差不多。

點選檢視程式碼
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i=(a);i<=(n);++i)
#define per(i,a,n) for(int i=(n);i>=(a);--i)
#define SZ(x) ((int)(x).size())

using namespace std;
typedef double db;
typedef long long ll;
typedef pair<int,int> pii;

const int maxn=5e4+10;

mt19937 rnd(time(0));
struct node{
	int l,r,pri,sz;
	ll val,tag,mx;
	bool rev;
}tr[maxn];
int cur,rt;

inline int newnode(int x){
	tr[++cur]={0,0,(int)rnd(),1,x,0,x,false};
	return cur;
}

inline void settag1(int p){
	if(!p) return;
	swap(tr[p].l,tr[p].r);
	tr[p].rev^=1;
}

inline void settag2(int p,ll v){
	if(!p) return;
	tr[p].mx+=v;
	tr[p].val+=v;
	tr[p].tag+=v;
}

inline void up(int p){
	tr[p].sz=1+tr[tr[p].l].sz+tr[tr[p].r].sz;
	tr[p].mx=tr[p].val;
	if(tr[p].l) tr[p].mx=max(tr[p].mx,tr[tr[p].l].mx);
	if(tr[p].r) tr[p].mx=max(tr[p].mx,tr[tr[p].r].mx);
}

inline void down(int p){
	if(tr[p].rev){
		settag1(tr[p].l); settag1(tr[p].r);
		tr[p].rev=false;
	}
	if(tr[p].tag){
		settag2(tr[p].l,tr[p].tag); settag2(tr[p].r,tr[p].tag);
		tr[p].tag=0;
	}
}

inline void split(int p,int sz,int &l,int &r){
	if(!p) return l=r=0,void();
	down(p);
	if(tr[tr[p].l].sz<sz){
		sz-=tr[tr[p].l].sz+1; l=p;
		split(tr[p].r,sz,tr[p].r,r);
	}
	else r=p,split(tr[p].l,sz,l,tr[p].l);
	up(p);
}

inline int merge(int l,int r){
	if(!l || !r) return l+r;
	if(tr[l].pri<tr[r].pri){
		down(l);
		tr[l].r=merge(tr[l].r,r);
		up(l);
		return l;
	}
	else{
		down(r);
		tr[r].l=merge(l,tr[r].l);
		up(r);
		return r;
	}
}

int n,m;

inline void solve(){
	scanf("%d%d",&n,&m);
	rep(i,1,n){
		int id=newnode(0);
		rt=merge(rt,id);
	}
	while(m--){
		int opt,l,r,v,L,M,R;
		scanf("%d%d%d",&opt,&l,&r);
		split(rt,r,L,R);
		split(L,l-1,L,M);
		if(opt==1){
			scanf("%d",&v);
			settag2(M,v);
		}
		else 
			if(opt==2) settag1(M);
			else printf("%lld\n",tr[M].mx);
		rt=merge(merge(L,M),R);
	}
	
	
}

signed main(){
    int _=1;
    //scanf("%d",&_); 
    while(_--) solve();
}

P2042

維護的東西多了很多,但原理還是一樣的。

點選檢視程式碼
#include<bits/stdc++.h>
#define rep(i,a,n) for(int i=(a);i<=(n);++i)
#define per(i,a,n) for(int i=(n);i>=(a);--i)
#define SZ(x) ((int)(x).size())

using namespace std;
typedef double db;
typedef long long ll;
typedef pair<int,int> pii;

const int maxn=5e5+10;
const int inf=1e9;

mt19937 rnd(time(0));

struct node{
	int l,r,val,pri,sz;
	bool rev;
	int pre,suf,sum,mss,tag;
}tr[maxn];

stack<int> trash;
int rt,cur;

inline int newnode(int x){
	int p;
	if(SZ(trash)) p=trash.top(),trash.pop();
	else p=++cur;
	if(x>=0) tr[p]={0,0,x,(int)rnd(),1,false,x,x,x,x,inf};
	else tr[p]={0,0,x,(int)rnd(),1,false,0,0,x,x,inf};
	return p; 
}

inline void settag1(int p){
	if(!p) return;
	tr[p].rev^=1;
	swap(tr[p].l,tr[p].r);
	swap(tr[p].pre,tr[p].suf);
}

inline void settag2(int p,int x){
	if(!p) return;
	tr[p].tag=tr[p].val=x; 
	if(x>=0) tr[p].pre=tr[p].suf=tr[p].mss=tr[p].sz*x;
	else tr[p].pre=tr[p].suf=0,tr[p].mss=x;
	tr[p].sum=tr[p].sz*x;
}

inline void up(int p){
	if(!p) return;
	int l=tr[p].l,r=tr[p].r;
	tr[p].sz=1+tr[l].sz+tr[r].sz;
	tr[p].sum=tr[l].sum+tr[r].sum+tr[p].val;
	tr[p].pre=max(tr[l].pre,tr[l].sum+tr[p].val+tr[r].pre);
	tr[p].suf=max(tr[r].suf,tr[r].sum+tr[p].val+tr[l].suf);
	tr[p].mss=max({tr[l].mss,tr[r].mss,tr[l].suf+tr[r].pre+tr[p].val});
}

inline void down(int p){
	if(!p) return;
	if(tr[p].rev){
		settag1(tr[p].l); settag1(tr[p].r);
		tr[p].rev=false;
	}
	if(tr[p].tag!=inf){
		settag2(tr[p].l,tr[p].tag); settag2(tr[p].r,tr[p].tag);
		tr[p].tag=inf;
	}
}

inline void split(int p,int sz,int &l,int &r){
	if(!p) return l=r=0,void();
	down(p);
	if(tr[tr[p].l].sz<sz){
		sz-=tr[tr[p].l].sz+1; l=p;
		split(tr[p].r,sz,tr[p].r,r);
	} 
	else{
		r=p;
		split(tr[p].l,sz,l,tr[p].l);
	}
	up(p);
}

inline int merge(int l,int r){
	if(!l || !r) return l+r;
	if(tr[l].pri < tr[r].pri){
		down(l);
		tr[l].r=merge(tr[l].r,r);
		up(l);
		return l;
	}
	else{
		down(r);
		tr[r].l=merge(l,tr[r].l);
		up(r);
		return r;
	}
}

inline void dispose(int p){
	if(!p) return;
	trash.push(p);
	if(tr[p].l) dispose(tr[p].l);
	if(tr[p].r) dispose(tr[p].r);
	tr[p].l=tr[p].r=0;
}

inline void dbg(int p){
	if(!p) return;
	if(tr[p].l) dbg(tr[p].l);
	printf("%d %d\n",p,tr[p].val);
	if(tr[p].r) dbg(tr[p].r);

}

int n,m,a[maxn];
char s[30];

inline int build(int l,int r){
	if(l==r) return newnode(a[l]);
	else{
		int mid=(l+r)>>1;
		return merge(build(l,mid),build(mid+1,r));
	}
}


inline void solve(){
	tr[0].mss=-inf;
	scanf("%d%d",&n,&m);
	rep(i,1,n){
		scanf("%d",&a[i]);
		int id=newnode(a[i]);
		rt=merge(rt,id);
	} 
	// dbg(rt);
	// puts("");
	while(m--){
		scanf("%s",s+1);		
		if(s[1]=='I'){
			int pos,tot,l,mid,r,x;
			scanf("%d%d",&pos,&tot);
			split(rt,pos,l,r);
			rep(i,1,tot) scanf("%d",&a[i]);
			mid=build(1,tot);
			rt=merge(merge(l,mid),r);
		}
		if(s[1]=='D'){
			int pos,tot,l,mid,r;
			scanf("%d%d",&pos,&tot);
			split(rt,pos+tot-1,l,r);
			split(l,pos-1,l,mid);
			dispose(mid);
			rt=merge(l,r);
		}
		if(s[1]=='M' && s[3]=='K'){
			int pos,tot,c,l,mid,r;
			scanf("%d%d%d",&pos,&tot,&c);
			split(rt,pos+tot-1,l,r);
			split(l,pos-1,l,mid);
			settag2(mid,c);
			rt=merge(merge(l,mid),r);
		}
		if(s[1]=='R'){
			int pos,tot,l,mid,r;
			scanf("%d%d",&pos,&tot);
			split(rt,pos+tot-1,l,r);
			split(l,pos-1,l,mid);
			settag1(mid);
			rt=merge(merge(l,mid),r);
		}
		if(s[1]=='G'){
			int pos,tot,l,mid,r;
			scanf("%d%d",&pos,&tot);
			split(rt,pos+tot-1,l,r);
			split(l,pos-1,l,mid);
			printf("%d\n",tr[mid].sum);
			rt=merge(merge(l,mid),r);
		}
		if(s[1]=='M' && s[3]=='X'){
			printf("%d\n",tr[rt].mss);
		} 
		// dbg(rt);
		// puts("");
	}
}

signed main(){
    int _=1;
    //scanf("%d",&_); 
    while(_--) solve();
}

相關文章