樹套樹
這裡主要介紹樹狀陣列套權值線段樹的方法,畢竟基本上所有的樹套樹題都能用這種方法解,並且時間複雜度都是 \(n\times (logn)^2\)。
思路
這裡有一道例題。
【模板】樹套樹
題目描述
您需要寫一種資料結構(可參考題目標題),來維護一個有序數列,其中需要提供以下操作:
-
查詢 \(k\) 在區間內的排名
-
查詢區間內排名為 \(k\) 的值
-
修改某一位置上的數值
-
查詢 \(k\) 在區間內的前驅(前驅定義為嚴格小於 \(x\),且最大的數,若不存在輸出
-2147483647
) -
查詢 \(k\) 在區間內的後繼(後繼定義為嚴格大於 \(x\),且最小的數,若不存在輸出
2147483647
)
輸入格式
第一行兩個數 \(n,m\),表示長度為 \(n\) 的有序序列和 \(m\) 個操作。
第二行有 \(n\) 個數,表示有序序列。
下面有 \(m\) 行,\(opt\) 表示操作標號。
若 \(opt=1\),則為操作 \(1\),之後有三個數 \(l~r~k\),表示查詢 \(k\) 在區間 \([l,r]\) 的排名。
若 \(opt=2\),則為操作 \(2\),之後有三個數 \(l~r~k\),表示查詢區間 \([l,r]\) 內排名為 \(k\) 的數。
若 \(opt=3\),則為操作 \(3\),之後有兩個數 \(pos~k\),表示將 \(pos\) 位置的數修改為 \(k\)。
若 \(opt=4\),則為操作 \(4\),之後有三個數 \(l~r~k\),表示查詢區間 \([l,r]\) 內 \(k\) 的前驅。
若 \(opt=5\),則為操作 \(5\),之後有三個數 \(l~r~k\),表示查詢區間 \([l,r]\) 內 \(k\) 的後繼。
輸出格式
對於操作 \(1,2,4,5\),各輸出一行,表示查詢結果。
樣例 #1
樣例輸入 #1
9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5
樣例輸出 #1
2
4
3
4
9
看完題目後可以發現這是一道樹套樹,然後下文主要講解如何使用這棵樹套樹。
顧名而思義,就是用樹狀陣列的方式來維護權值線段樹(動態開點),我們對於上述的 \(5\) 個操作分別來看一下如何實現。
-
操作 \(1\) 查詢 \(l\sim r\) 中 \(k\) 的排名,我們會只放在權值線段樹上的做法,這裡就是會多維護 \(2\) 個陣列,就是和普通的樹狀陣列一樣,將每一次的 \(l,r\) 都存下來,然後在查詢中用 \(r\) 的總和減去 \(l\) 的即可,記住在往另一個地方遞迴時要更新這兩個陣列。
int rk1(int l,int r,int k) { if(l==r) { return 1; } int mid=(l+r)/2,sum=false; rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;//和普通的樹狀陣列相同 rep(i,1,cp) sum+=tr[tr[p[i]].l].sum; if(mid>=k) { rep(i,1,cs) s[i]=tr[s[i]].l;//向那一邊遞迴也要將 l,r 陣列改一下 rep(i,1,cp) p[i]=tr[p[i]].l; return rk1(l,mid,k); }else{ rep(i,1,cs) s[i]=tr[s[i]].r; rep(i,1,cp) p[i]=tr[p[i]].r; return sum+rk1(mid+1,r,k); } } l--;//用 r 的減去 l-1 的就為 l~r 中的 cs=cp=false;//清空 for(;l;l-=lowbit(l)) s[++cs]=rt[l];//與樹狀陣列模板一樣 for(;r;r-=lowbit(r)) p[++cp]=rt[r];
-
對於操作二,其實和 \(1\) 的實現過程一樣,就是在普通權值線段樹上加上了 \(l,r\) 陣列的改變而已。
int Ans(int l,int r,int k) { if(l==r) return l; int mid=(l+r)>>1; int sum=false; rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;//同理 rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;//同理 if(k<=sum) { rep(i,1,cs) s[i]=tr[s[i]].l;//改變 rep(i,1,cp) p[i]=tr[p[i]].l; return Ans(l,mid,k); }else { rep(i,1,cs) s[i]=tr[s[i]].r;//改變 rep(i,1,cp) p[i]=tr[p[i]].r; return Ans(mid+1,r,k-sum); } } l--;//用 r 的減去 l-1 的就為 l~r 中的 cs=cp=false;//清空 for(;l;l-=lowbit(l)) s[++cs]=rt[l];//與樹狀陣列模板一樣 for(;r;r-=lowbit(r)) p[++cp]=rt[r];
-
操作三是最簡單的直接修改即可,這裡可以直接結合樹狀陣列的方式直接將每一個都 modify 一下即可。
void modify(int &u,int l,int r,int k,int cnt) { if(!u) u=++idx;//動態開點 tr[u].sum+=cnt;//加上 if(l==r) return; int mid=(l+r)/2; if(mid>=k) modify(tr[u].l,l,mid,k,cnt); else modify(tr[u].r,mid+1,r,k,cnt); } in(l),in(k); for(int i=l;i<=n;i+=lowbit(i)) modify(rt[i],0,Max,a[l],-1);//先減後加 a[l]=k; for(int i=l;i<=n;i+=lowbit(i)) modify(rt[i],0,Max,a[l],1);
-
操作四,這裡我不會直接轉移所以用了一下二分一下排名,直接看排名為 \(mid\) 的數是否小於 \(k\) 即可。
in(l),in(r),in(k); int L=1,R=r-l+1,res=false; while(L<=R) { int mid=L+R>>1; cs=cp=false; for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i]; for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i]; if(Ans(0,Max,mid)<k) res=mid,L=mid+1; else R=mid-1; } if(!res) { cout<<"-2147483647\n"; continue; } cs=cp=false; for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i]; for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i]; cout<<Ans(0,Max,res)<<endl;
-
操作五同理就是將小於改為大於即可。
總程式碼
#include <bits/stdc++.h>
using namespace std;
#define rep(i,x,y) for(register int i=x;i<=y;i++)
#define rep1(i,x,y) for(register int i=x;i>=y;--i)
#define in(x) scanf("%d",&x)
#define ll long long
#define fire signed
#define il inline
il void print(int x) {
if(x<0) putchar('-'),x=-x;
if(x>=10) print(x/10);
putchar(x%10+'0');
}
int T;
const int N=5e4+10;
struct node{
int l,r;
int sum;
}tr[N*2*16*16];
const int Max=1e8+1;
int n,m,idx;
int cp,cs;
int p[N],s[N];
void modify(int &u,int l,int r,int k,int cnt) {
if(!u) u=++idx;
tr[u].sum+=cnt;
if(l==r) return;
int mid=(l+r)/2;
if(mid>=k) modify(tr[u].l,l,mid,k,cnt);
else modify(tr[u].r,mid+1,r,k,cnt);
}
int rt[N],a[N];
int lowbit(int x) {
return x&-x;
}
int rk(int l,int r,int k) {
if(l==r) {
int sum=false;
rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;
if(!sum) sum=1;
return sum;
}
int mid=(l+r)/2,sum=false;
rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;
if(mid>=k) {
rep(i,1,cs) s[i]=tr[s[i]].l;
rep(i,1,cp) p[i]=tr[p[i]].l;
return rk(l,mid,k);
}else{
rep(i,1,cs) s[i]=tr[s[i]].r;
rep(i,1,cp) p[i]=tr[p[i]].r;
return sum+rk(mid+1,r,k);
}
}
int rk2(int l,int r,int k) {
if(l==r) {
int sum=false;
rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;
return sum;
}
int mid=(l+r)/2,sum=false;
rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;
if(mid>=k) {
rep(i,1,cs) s[i]=tr[s[i]].l;
rep(i,1,cp) p[i]=tr[p[i]].l;
return rk2(l,mid,k);
}else{
rep(i,1,cs) s[i]=tr[s[i]].r;
rep(i,1,cp) p[i]=tr[p[i]].r;
return sum+rk2(mid+1,r,k);
}
}
int rk1(int l,int r,int k) {
if(l==r) {
return 1;
}
int mid=(l+r)/2,sum=false;
rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;
if(mid>=k) {
rep(i,1,cs) s[i]=tr[s[i]].l;
rep(i,1,cp) p[i]=tr[p[i]].l;
return rk1(l,mid,k);
}else{
rep(i,1,cs) s[i]=tr[s[i]].r;
rep(i,1,cp) p[i]=tr[p[i]].r;
return sum+rk1(mid+1,r,k);
}
}
int Ans(int l,int r,int k) {
if(l==r) return l;
int mid=(l+r)>>1;
int sum=false;
rep(i,1,cs) sum-=tr[tr[s[i]].l].sum;
rep(i,1,cp) sum+=tr[tr[p[i]].l].sum;
if(k<=sum) {
rep(i,1,cs) s[i]=tr[s[i]].l;
rep(i,1,cp) p[i]=tr[p[i]].l;
return Ans(l,mid,k);
}else {
rep(i,1,cs) s[i]=tr[s[i]].r;
rep(i,1,cp) p[i]=tr[p[i]].r;
return Ans(mid+1,r,k-sum);
}
}
void solve() {
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
in(n),in(m);
rep(i,1,n) {
in(a[i]);
for(int j=i;j<=n;j+=lowbit(j)) modify(rt[j],0,Max,a[i],1);
}
while(m--) {
int opt;
int l,r,k;
in(opt);
if(opt==1) {
in(l),in(r),in(k);
cs=cp=false;
l--;
for(;l;l-=lowbit(l)) s[++cs]=rt[l];
for(;r;r-=lowbit(r)) p[++cp]=rt[r];
cout<<rk1(0,Max,k)<<endl;
}else if(opt==2){
in(l),in(r),in(k);
cs=cp=false;
l--;
for(;l;l-=lowbit(l)) s[++cs]=rt[l];
for(;r;r-=lowbit(r)) p[++cp]=rt[r];
cout<<Ans(0,Max,k)<<endl;
}else if(opt==3) {
in(l),in(k);
for(int i=l;i<=n;i+=lowbit(i)) modify(rt[i],0,Max,a[l],-1);
a[l]=k;
for(int i=l;i<=n;i+=lowbit(i)) modify(rt[i],0,Max,a[l],1);
}else if(opt==4) {
in(l),in(r),in(k);
int L=1,R=r-l+1,res=false;
while(L<=R) {
int mid=L+R>>1;
cs=cp=false;
for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
if(Ans(0,Max,mid)<k) res=mid,L=mid+1;
else R=mid-1;
}
if(!res) {
cout<<"-2147483647\n";
continue;
}
cs=cp=false;
for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
cout<<Ans(0,Max,res)<<endl;
}else {
in(l),in(r),in(k);
int L=1,R=r-l+1,res=false;
while(L<=R) {
int mid=L+R>>1;
cs=cp=false;
for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
if(Ans(0,Max,mid)>k) res=mid,R=mid-1;
else L=mid+1;
}
if(res==0) cout<<"2147483647\n";
else {
cs=cp=false;
for(int i=l-1;i;i-=lowbit(i)) s[++cs]=rt[i];
for(int i=r;i;i-=lowbit(i)) p[++cp]=rt[i];
cout<<Ans(0,Max,res)<<endl;
}
}
}
return;
}
fire main() {
T=1;
while(T--) {
solve();
}
return false;
}