題意
給你一個長度為 $ n $ 有序數列 $ a $ ,進行 $ m $ 次操作,操作有如下幾種:
- 查詢 $ k $ 在區間 $ [l,r] $ 內的排名
- 查詢區間 $ [l,r] $ 內排名為 $ k $ 的值
- 將 $ a[p] $ 修改為 $ k $
- 查詢 $ k $ 在區間 $ [l,r] $ 內的前驅(前驅定義為小於 $ k $ ,且最大的數)
- 查詢 $ k $ 在區間 $ [l,r] $ 內的後繼(後繼定義為大於 $ k $ ,且最小的數)
題解
線段樹套splay。
先將 $ n $ 個數插入線段樹:對於每個 $ a[i] $,將線段樹上到位置 $ i $ 的葉子節點的路徑上的所有splay插入元素 $ a[i] $ 。
操作1:區間 $ [l,r] $ 線上段樹上對應了若干棵splay,將這些splay中小於 $ k $ 的元素個數累加,記為 $ sum $ ,$ sum+1 $ 即為答案。
操作2:二分這個元素的值,然後進行操作1得到當前rank,對應地調整答案。
操作3:將線段樹上到位置 $ p $ 的葉子節點的路徑上的所有splay刪除 $ a[p] $ ,再插入 $ k $ ,然後更新 $ a[p] = k $ 。
操作4:將區間 $ [l,r] $ 對應的所有splay中查詢到的 $ k $ 的前驅取 $ max $ 即可。
操作5:將區間 $ [l,r] $ 對應的所有splay中查詢到的 $ k $ 的後繼取 $ min $ 即可。
最後,紀念一下我用pbds加map封裝的的假splay......QAQ
還有就是因為b站g++版本太老,null_type
會CE,要改成null_mapped_type
。
AC Code
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define MAX_N 50005
#define MAX_V 200005
#define INF 2147483647
using namespace std;
using namespace __gnu_pbds;
typedef tree<pair<int,int>,null_mapped_type,less<pair<int,int> >,rb_tree_tag,tree_order_statistics_node_update> Tree;
typedef Tree::iterator git;
struct Splay
{
Tree t;
map<int,int> mp;
void ins(int x)
{
t.insert(make_pair(x,mp[x]=mp[x]+1));
}
void del(int x)
{
t.erase(make_pair(x,mp[x])),mp[x]=mp[x]-1;
}
int pre(int x)
{
if(t.empty()) return -INF;
git it=t.lower_bound(make_pair(x,0));
if(it==t.begin()) return -INF;
return (--it)->first;
}
int suc(int x)
{
if(t.empty()) return INF;
git it=t.upper_bound(make_pair(x,INF));
if(it==t.end()) return INF;
return it->first;
}
int kth(int x)
{
return t.find_by_order(x-1)->first;
}
int rk(int x)
{
return t.order_of_key(make_pair(x,1))+1;
}
};
int n,m;
int a[MAX_N];
Splay t[MAX_V];
void ins(int p,int k,int l,int r,int x)
{
t[k].ins(x);
if(l==r) return;
int mid=(l+r)>>1;
if(p<=mid) ins(p,k*2+1,l,mid,x);
else ins(p,k*2+2,mid+1,r,x);
}
int getrk(int a,int b,int k,int l,int r,int x)
{
if(a<=l && r<=b) return t[k].rk(x)-1;
int mid=(l+r)>>1,ans=0;
if(a<=mid) ans+=getrk(a,b,k*2+1,l,mid,x);
if(b>mid) ans+=getrk(a,b,k*2+2,mid+1,r,x);
return ans;
}
int getx(int a,int b,int k)
{
int l=0,r=INF;
while(r-l>1)
{
int mid=(l+r)>>1;
if(getrk(a,b,0,1,n,mid)<=k-1) l=mid;
else r=mid;
}
return l;
}
void upd(int p,int k,int l,int r,int x)
{
t[k].del(a[p]),t[k].ins(x);
if(l==r) return;
int mid=(l+r)>>1;
if(p<=mid) upd(p,k*2+1,l,mid,x);
else upd(p,k*2+2,mid+1,r,x);
}
int getpre(int a,int b,int k,int l,int r,int x)
{
if(a<=l && r<=b) return t[k].pre(x);
int mid=(l+r)>>1,ans=-INF;
if(a<=mid) ans=max(ans,getpre(a,b,k*2+1,l,mid,x));
if(b>mid) ans=max(ans,getpre(a,b,k*2+2,mid+1,r,x));
return ans;
}
int getsuc(int a,int b,int k,int l,int r,int x)
{
if(a<=l && r<=b) return t[k].suc(x);
int mid=(l+r)>>1,ans=INF;
if(a<=mid) ans=min(ans,getsuc(a,b,k*2+1,l,mid,x));
if(b>mid) ans=min(ans,getsuc(a,b,k*2+2,mid+1,r,x));
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]),ins(i,0,1,n,a[i]);
int opt,l,r,k,p;
while(m--)
{
scanf("%d",&opt);
if(opt==1)
{
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",getrk(l,r,0,1,n,k)+1);
}
if(opt==2)
{
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",getx(l,r,k));
}
if(opt==3)
{
scanf("%d%d",&p,&k);
upd(p,0,1,n,k),a[p]=k;
}
if(opt==4)
{
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",getpre(l,r,0,1,n,k));
}
if(opt==5)
{
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",getsuc(l,r,0,1,n,k));
}
}
}