P4688 Ynoi2016 掉進兔子洞
經典莫隊加 bitset
。
思路
不難發現最終答案就是:
\[(r_1-l_1+1)+(r_2-l_2+1)+(r_3-l_3+1)-3\times size
\]
其中 \(size\) 表示 3 個區間內出現了多少個公共元素。
看到這麼多區間,不妨有把區間拆下來搞莫隊的想法。
先不考慮詢問個數的限制,我們考慮使用 bitset
維護出現多少個公共元素。
然而 bitset
維護出來的是多少種而不是個數。
又然而我們可以先將序列離散化,離散化時每個元素的新值賦為小於等於它的元素的個數。
在莫隊加入一個節點時,把 bitset
中的第 \(p-cnt_p\) 位標為 \(1\)。
\(cnt_p\) 為當前區間內 \(p\) 元素的個數。
然後你就會發現,bitset
中不同的值的元素所儲存位置是不同的(這個顯而易見)。
然後你又發現,bitset
中不同區域的 \(1\) 的個數代表了某些值相等的元素的個數。
接著你把 3 個區間分別的 bitset
求交,統計 \(1\) 的個數,就求出了 \(size\)。
然而一次性不可以處理這麼多的區間,我們把詢問分組處理即可。
時間複雜度 \(\frac{n^2\sqrt n}{w}\),然而卡不到莫隊上限。
CODE
#include<bits/stdc++.h>
using namespace std;
#define M 2e4
const int maxn=1e5+5,maxm=2e4+5;
struct qry
{
int l,r,t;
}q[maxm*3];
int n,m,tot=1;
int a[maxn],cnt[maxn],nans[maxm];
bitset<maxn>ans[maxm],nb;
map<int,int>mp;
inline bool cmp1(qry a,qry b){return a.l<b.l;}
inline bool cmp2(qry a,qry b){return a.r<b.r;}
inline void ins(int a){nb[a-cnt[a]]=1;cnt[a]++;}
inline void del(int a){cnt[a]--;nb[a-cnt[a]]=0;}
inline void solve()
{
if(tot>=m) return ;
int tp=0;
for(int i=1;i<=M&&tot<=m;i++,tot++)
{
++tp;scanf("%d%d",&q[tp].l,&q[tp].r);q[tp].t=i;nans[i]+=q[tp].r-q[tp].l+1;
++tp;scanf("%d%d",&q[tp].l,&q[tp].r);q[tp].t=i;nans[i]+=q[tp].r-q[tp].l+1;
++tp;scanf("%d%d",&q[tp].l,&q[tp].r);q[tp].t=i;nans[i]+=q[tp].r-q[tp].l+1;
}
for(int i=1;i<=tp/3;i++) ans[i].set();
sort(q+1,q+tp+1,cmp1);
for(int i=1;i<=tp;i+=320)
{
int r=min(i+319,tp);
sort(q+i,q+r+1,cmp2);
}
int nl=0,nr=0;
for(int i=1;i<=tp;i++)
{
if(nr<q[i].l)
{
for(int j=nl;j<=nr;j++) del(a[j]);
nl=q[i].l,nr=q[i].r;
for(int j=nl;j<=nr;j++) ins(a[j]);
}
else
{
while(nl<q[i].l) del(a[nl]),nl++;
while(nl>q[i].l) nl--,ins(a[nl]);
while(nr<q[i].r) nr++,ins(a[nr]);
while(nr>q[i].r) del(a[nr]),nr--;
}
ans[q[i].t]&=nb;
}
for(int i=nl;i<=nr;i++) del(a[i]);
for(int i=1;i<=tp/3;i++) printf("%lld\n",nans[i]-ans[i].count()*3);
for(int i=1;i<=tp/3;i++) nans[i]=0;tp=0;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]),mp[a[i]]++;
map<int,int>::iterator it2,it1;
for(it1=mp.begin(),it2=it1,it2++;it2!=mp.end();it1++,it2++) it2->second+=it1->second;
for(int i=1;i<=n;i++) a[i]=mp[a[i]];
for(int i=1;i<=5;i++) solve();
}