樹套樹

highkj發表於2024-04-07

樹套樹

這裡主要介紹樹狀陣列套權值線段樹的方法,畢竟基本上所有的樹套樹題都能用這種方法解,並且時間複雜度都是 \(n\times (logn)^2\)

思路

這裡有一道例題。

【模板】樹套樹

題目描述

您需要寫一種資料結構(可參考題目標題),來維護一個有序數列,其中需要提供以下操作:

  1. 查詢 \(k\) 在區間內的排名

  2. 查詢區間內排名為 \(k\) 的值

  3. 修改某一位置上的數值

  4. 查詢 \(k\) 在區間內的前驅(前驅定義為嚴格小於 \(x\),且最大的數,若不存在輸出 -2147483647

  5. 查詢 \(k\) 在區間內的後繼(後繼定義為嚴格大於 \(x\),且最小的數,若不存在輸出 2147483647

輸入格式

第一行兩個數 \(n,m\),表示長度為 \(n\) 的有序序列和 \(m\) 個操作。

第二行有 \(n\) 個數,表示有序序列。

下面有 \(m\) 行,\(opt\) 表示操作標號。

\(opt=1\),則為操作 \(1\),之後有三個數 \(l~r~k\),表示查詢 \(k\) 在區間 \([l,r]\) 的排名。

\(opt=2\),則為操作 \(2\),之後有三個數 \(l~r~k\),表示查詢區間 \([l,r]\) 內排名為 \(k\) 的數。

\(opt=3\),則為操作 \(3\),之後有兩個數 \(pos~k\),表示將 \(pos\) 位置的數修改為 \(k\)

\(opt=4\),則為操作 \(4\),之後有三個數 \(l~r~k\),表示查詢區間 \([l,r]\)\(k\) 的前驅。

\(opt=5\),則為操作 \(5\),之後有三個數 \(l~r~k\),表示查詢區間 \([l,r]\)\(k\) 的後繼。

輸出格式

對於操作 \(1,2,4,5\),各輸出一行,表示查詢結果。

樣例 #1

樣例輸入 #1

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

樣例輸出 #1

2
4
3
4
9

看完題目後可以發現這是一道樹套樹,然後下文主要講解如何使用這棵樹套樹。

顧名而思義,就是用樹狀陣列的方式來維護權值線段樹(動態開點),我們對於上述的 \(5\) 個操作分別來看一下如何實現。

  • 操作 \(1\) 查詢 \(l\sim r\)\(k\) 的排名,我們會只放在權值線段樹上的做法,這裡就是會多維護 \(2\) 個陣列,就是和普通的樹狀陣列一樣,將每一次的 \(l,r\) 都存下來,然後在查詢中用 \(r\) 的總和減去 \(l\) 的即可,記住在往另一個地方遞迴時要更新這兩個陣列。

    int rk1(int l,int r,int k) {
    	if(l==r) {
    		return 1;
    	}
    	int mid=(l+r)/2,sum=false;
    	rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;//和普通的樹狀陣列相同
    	rep(i,1,cp) sum+=tr[tr[p[i]].l].sum; 
    	if(mid>=k) {
    		rep(i,1,cs) s[i]=tr[s[i]].l;//向那一邊遞迴也要將 l,r 陣列改一下
    		rep(i,1,cp) p[i]=tr[p[i]].l;
    		return rk1(l,mid,k);
    	}else{
    		rep(i,1,cs) s[i]=tr[s[i]].r;
    		rep(i,1,cp) p[i]=tr[p[i]].r;
    		return sum+rk1(mid+1,r,k);
    	} 
    }
    l--;//用 r 的減去 l-1 的就為 l~r 中的
    cs=cp=false;//清空
    for(;l;l-=lowbit(l)) s[++cs]=rt[l];//與樹狀陣列模板一樣
    for(;r;r-=lowbit(r)) p[++cp]=rt[r];
    
  • 對於操作二,其實和 \(1\) 的實現過程一樣,就是在普通權值線段樹上加上了 \(l,r\) 陣列的改變而已。

    int Ans(int l,int r,int k) {
    	if(l==r) return l;
    	int mid=(l+r)>>1;
    	int sum=false;
    	rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;//同理
    	rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;//同理
    	if(k<=sum) {
    		rep(i,1,cs) s[i]=tr[s[i]].l;//改變
    		rep(i,1,cp) p[i]=tr[p[i]].l;
    		return Ans(l,mid,k);
    	}else {
    		rep(i,1,cs) s[i]=tr[s[i]].r;//改變
    		rep(i,1,cp) p[i]=tr[p[i]].r;
    		return Ans(mid+1,r,k-sum);
        }
    }
    l--;//用 r 的減去 l-1 的就為 l~r 中的
    cs=cp=false;//清空
    for(;l;l-=lowbit(l)) s[++cs]=rt[l];//與樹狀陣列模板一樣
    for(;r;r-=lowbit(r)) p[++cp]=rt[r];
    
  • 操作三是最簡單的直接修改即可,這裡可以直接結合樹狀陣列的方式直接將每一個都 modify 一下即可。

    void modify(int &u,int l,int r,int k,int cnt) {
    	if(!u) u=++idx;//動態開點
    	tr[u].sum+=cnt;//加上
    	if(l==r) return;
    	int mid=(l+r)/2;
    	if(mid>=k) modify(tr[u].l,l,mid,k,cnt);
    	else modify(tr[u].r,mid+1,r,k,cnt);
    }
    in(l),in(k);
    for(int i=l;i<=n;i+=lowbit(i)) modify(rt[i],0,Max,a[l],-1);//先減後加
    a[l]=k;
    for(int i=l;i<=n;i+=lowbit(i)) modify(rt[i],0,Max,a[l],1);
    
  • 操作四,這裡我不會直接轉移所以用了一下二分一下排名,直接看排名為 \(mid\) 的數是否小於 \(k\) 即可。

    in(l),in(r),in(k);
    int L=1,R=r-l+1,res=false;
    while(L<=R) {
    	int mid=L+R>>1;
    	cs=cp=false;
    	for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
    	for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
    	if(Ans(0,Max,mid)<k) res=mid,L=mid+1;
    	else R=mid-1;
    }
    if(!res) {
    	cout<<"-2147483647\n";
    	continue;
    }
    cs=cp=false;
    for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
    for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
    cout<<Ans(0,Max,res)<<endl;
    
  • 操作五同理就是將小於改為大於即可。

總程式碼

#include <bits/stdc++.h>
using namespace std;
#define rep(i,x,y) for(register int i=x;i<=y;i++)
#define rep1(i,x,y) for(register int i=x;i>=y;--i)
#define in(x) scanf("%d",&x)
#define ll long long
#define fire signed
#define il inline
il void print(int x) {
	if(x<0) putchar('-'),x=-x;
	if(x>=10) print(x/10);
	putchar(x%10+'0');
}
int T;
const int N=5e4+10;
struct node{
	int l,r;
	int sum;
}tr[N*2*16*16];
const int Max=1e8+1;
int n,m,idx;
int cp,cs;
int p[N],s[N];
void modify(int &u,int l,int r,int k,int cnt) {
	if(!u) u=++idx;
	tr[u].sum+=cnt;
	if(l==r) return;
	int mid=(l+r)/2;
	if(mid>=k) modify(tr[u].l,l,mid,k,cnt);
	else modify(tr[u].r,mid+1,r,k,cnt);
}
int rt[N],a[N];
int lowbit(int x) {
	return x&-x;
}
int rk(int l,int r,int k) {
	if(l==r) {
		int sum=false;
		rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
		rep(i,1,cp) sum+=tr[tr[p[i]].l].sum; 
		if(!sum) sum=1;
		return sum;
	}
	int mid=(l+r)/2,sum=false;
	rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
	rep(i,1,cp) sum+=tr[tr[p[i]].l].sum; 
	if(mid>=k) {
		rep(i,1,cs) s[i]=tr[s[i]].l;
		rep(i,1,cp) p[i]=tr[p[i]].l;
		return rk(l,mid,k);
	}else{
		rep(i,1,cs) s[i]=tr[s[i]].r;
		rep(i,1,cp) p[i]=tr[p[i]].r;
		return sum+rk(mid+1,r,k);
	} 
}
int rk2(int l,int r,int k) {
	if(l==r) {
		int sum=false;
		rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
		rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;
		return sum;
	}
	int mid=(l+r)/2,sum=false;
	rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
	rep(i,1,cp) sum+=tr[tr[p[i]].l].sum; 
	if(mid>=k) {
		rep(i,1,cs) s[i]=tr[s[i]].l;
		rep(i,1,cp) p[i]=tr[p[i]].l;
		return rk2(l,mid,k);
	}else{
		rep(i,1,cs) s[i]=tr[s[i]].r;
		rep(i,1,cp) p[i]=tr[p[i]].r;
		return sum+rk2(mid+1,r,k);
	} 
}
int rk1(int l,int r,int k) {
	if(l==r) {
		return 1;
	}
	int mid=(l+r)/2,sum=false;
	rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
	rep(i,1,cp) sum+=tr[tr[p[i]].l].sum; 
	if(mid>=k) {
		rep(i,1,cs) s[i]=tr[s[i]].l;
		rep(i,1,cp) p[i]=tr[p[i]].l;
		return rk1(l,mid,k);
	}else{
		rep(i,1,cs) s[i]=tr[s[i]].r;
		rep(i,1,cp) p[i]=tr[p[i]].r;
		return sum+rk1(mid+1,r,k);
	} 
}
int Ans(int l,int r,int k) {
	if(l==r) return l;
	int mid=(l+r)>>1;
	int sum=false;
	rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
	rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;
	if(k<=sum) {
		rep(i,1,cs) s[i]=tr[s[i]].l;
		rep(i,1,cp) p[i]=tr[p[i]].l;
		return Ans(l,mid,k);
	}else {
		rep(i,1,cs) s[i]=tr[s[i]].r;
		rep(i,1,cp) p[i]=tr[p[i]].r;
		return Ans(mid+1,r,k-sum);
	}
}
void solve() {
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	in(n),in(m);
	rep(i,1,n) {
		in(a[i]);
		for(int j=i;j<=n;j+=lowbit(j)) modify(rt[j],0,Max,a[i],1);
	}
	while(m--) {
		int opt;
		int l,r,k;
		in(opt);
		if(opt==1) {
			in(l),in(r),in(k);
			cs=cp=false;
			l--;
			for(;l;l-=lowbit(l)) s[++cs]=rt[l];
			for(;r;r-=lowbit(r)) p[++cp]=rt[r];
			cout<<rk1(0,Max,k)<<endl;
		}else if(opt==2){
			in(l),in(r),in(k);
			cs=cp=false;
			l--;
			for(;l;l-=lowbit(l)) s[++cs]=rt[l];
			for(;r;r-=lowbit(r)) p[++cp]=rt[r];
			cout<<Ans(0,Max,k)<<endl;
		}else if(opt==3) {
			in(l),in(k);
			for(int i=l;i<=n;i+=lowbit(i)) modify(rt[i],0,Max,a[l],-1);
			a[l]=k;
			for(int i=l;i<=n;i+=lowbit(i)) modify(rt[i],0,Max,a[l],1);
		}else if(opt==4) {
			in(l),in(r),in(k);
			int L=1,R=r-l+1,res=false;
			while(L<=R) {
				int mid=L+R>>1;
				cs=cp=false;
				for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
				for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
				if(Ans(0,Max,mid)<k) res=mid,L=mid+1;
				else R=mid-1;
			}
			if(!res) {
				cout<<"-2147483647\n";
				continue;
			}
			cs=cp=false;
			for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
			for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
			cout<<Ans(0,Max,res)<<endl;
		}else {
			in(l),in(r),in(k);
			int L=1,R=r-l+1,res=false;
			while(L<=R) {
				int mid=L+R>>1;
				cs=cp=false;
				for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
				for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
				if(Ans(0,Max,mid)>k) res=mid,R=mid-1;
				else L=mid+1;
			}
			if(res==0) cout<<"2147483647\n";
			else {
				cs=cp=false;
				for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
				for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
				cout<<Ans(0,Max,res)<<endl;
			}
		}
	}
	return;
}
fire main() {
	T=1;
	while(T--) {
		solve();
	}
	return false;
}

相關文章