異或與區間加題解

zYzYzYzYz發表於2024-05-01

異或與區間加題解

簡要題意

給定 \(n,m,K,a_{1...n}\),和 \(m\) 個三元組 \((x_i,y_i,z_i)\),定義 \(calc(l,r)=a_l \bigoplus a_{l+1}\bigoplus ...\bigoplus a_r\)。對於每個三元組 \((x,y,z)\) ,對所有滿足 \(x\le l\le r\le y\ ,\ calc(l,r)=K\) 的區間 \((l,r)\) 內的每個數 \(b_i\) 加上 \(z\),其中 \(b_{1..n}\)​​ 初始全為 0。輸出對 \(2^{30}\) 取模。

\(0\le K,a_i<2^{30},1\le x\le y\le n,0\le z\le 10000\)

10 10 3//n m K
2 0 3 0 1 0 0 2 1 2//a[i]
1 10 1//x y z
3 10 9
10 10 5
4 10 10
9 10 8
7 7 8
3 5 10
7 8 9
7 9 7
7 8 7
1 4 54 53 52 72 99 126 114 39

題解

先來一個暴力的方法。首先容易想到對 \(a\) 求一遍字首和,將 \(calc(l,r)=K\) 轉化為 \(sum_r\bigoplus sum_{l-1}=K\)。將每個三元組按關鍵字排序(先x後y),然後從前往後掃描每一個區間。然後開一個樹狀陣列,令 \(c_{x..y}\) 加上 \(z\)\(c_{pos}\) 表示:對於每個右端點位於 \(pos\) 的區間 \((l,r)\),應對的 \(b_{l...r}\) 需要加上 \(c_{pos}\)。但是這樣可能會把 \(l<x\) 的區間也進行操作,所以我們應該從前往後掃描每一個位置。在掃描到位置 \(l\) 的時候,如果發現存在三元組滿足 \(x=l\),那麼我們令 \(c_{x..y}\) 加上 \(z\)。處理完 \(c\) 以後,找到滿足 \(sum_r\bigoplus sum_{l-1}=K\)\(r\)(可以利用map找),然後再將 \(b_{l...r}\) 加上 \(c_{r}\),這一個區間加可以用差分處理。

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=1<<30;
int n,m,K,a[150010],sum[150010];
LL c[150010],cc[150010];
unordered_map<int,vector<int>>mp;
struct SYZ
{int x,y,z;}syz[150010];
inline int read()
{
	int x=0,w=0;char ch=0;
	while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
	while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return w?-x:x;
}
bool cmp(SYZ n1,SYZ n2)
{
	if(n1.x^n2.x)return n1.x<n2.x;
	return n1.y<n2.y;
}
void change(int x,int y)
{for(;x<=n;x+=x&-x)c[x]+=y;}
int ask(int x,int y=0)
{for(;x;x-=x&-x)y+=c[x];return y;}
int main()
{
	n=read();m=read();K=read();
	mp[0].push_back(0);
	for(int i=1;i<=n;i++)
		mp[sum[i]=sum[i-1]^(a[i]=read())].push_back(i);
	for(int i=1;i<=m;i++){
		int x=read(),y=read(),z=read();
		syz[i]=(SYZ){x,y,z};
	}
	sort(syz+1,syz+1+m,cmp);
	for(auto&x:mp)//x.first是鍵,x.second是值
		reverse(x.second.begin(),x.second.end());
	for(int i=1,j=1;i<=n;i++){//i是左端點
		while(j<=m&&syz[j].x==i)
			change(1,syz[j].z),change(syz[j].y+1,-syz[j].z),j++;
		for(int x:mp[sum[i-1]^K]){//x是右端點
			if(x<i)break;
			int temp=ask(x);
			cc[i]+=temp;
			cc[x+1]-=temp;
		}
	}
	for(int i=1;i<=n;i++)
		printf("%lld%c",(cc[i]+=cc[i-1])%=mod," \n"[i==n]);
}

這裡的 \(cc\)\(b\) 的差分陣列。

此方法的瓶頸在於:對於一個 \(l\) ,滿足條件的 \(r\) 可能會非常多。

我們可以在當 \(r\) 的數量小於 \(\sqrt{n}\) 時用上述方法,當 \(r\) 數量過多時需要換一種方法。值得注意的是,這樣不同的 \(sum_r\) 不會超過 \(\sqrt{n}\) 個。

我們不妨對每一個這樣的 \(sum_r\) 單獨處理,我們先暴力找到所有的 \(l\)\(r\) ,利用字首和可以計算出區間 \(xx,yy\) 內有多少個 \(l\)\(r\)

\(prez_i\) 表示:滿足 \(x\le i\le y\) 的所有三元組的 \(z\) 的和,利用差分可以快速求出。

我們要分別掃描所有的 \(l\),\(r\) ,掃描 \(l\) 時,讓 \(cc_l\) 加上一些東西;掃描 \(r\) 時,讓 \(cc_{r+1}\) 減去一些東西。

我們不妨先考慮一個弱化版本,即:\(y=n\)。如果此方法可行的話,我們可以試圖將 \((x,y,z)\) 拆分成 \((x,n,z)\)\((y+1,n,-z)\)。注意,直接拆分會錯誤地統計上這樣的區間:\(x\le l\le y<r\)。我們需要額外的操作減去這樣的貢獻。

顯然,對於 \((x,y,z)\) 只需要 \(l\ge x\) 即可。我們可以掃描每一個 \(l\) ,計算 \(x\le l\)\(z\) 的和,以及 \(r\) 的數量。前者即是 \(prez_l\),後者用差分統計即可。有:\(cc_l+=cnt(r)*prez_l\)。同樣地,對於 \(r\) 我們沿用類似的方法,但稍微麻煩點。對於 \(cc_{r+1}\) 我們要減去的是:每一個 \(x\le l\) 對應的 \(z\)。這個可以一遍掃描一遍統計,初始令 \(temp=0\) ,從左往右掃描時,如果 \(pos\) 是左端點,則令 \(temp+=prez_{pos}\)。當掃描到一個右端點 \(r\) 時,令 \(cc_{r+1}-=temp\)。意思是遇到一個左端點 \(l\),那麼它後面的右端點統統加上它左邊的三元組的 \(z\) (即 \(prez_l\))。

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=1<<30;
int n,m,K,B,a[150010],sum[150010];
int sX[150010],sY[150010];
LL c[150010],cc[150010],prez[150010];
unordered_map<int,vector<int>>mp;
struct SYZ
{int x,y,z;}syz[150010];
inline int read()
{
	int x=0,w=0;char ch=0;
	while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
	while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return w?-x:x;
}
bool cmp(SYZ n1,SYZ n2)
{
	if(n1.x^n2.x)return n1.x<n2.x;
	return n1.y<n2.y;
}
void change(int x,int y)
{for(;x<=n;x+=x&-x)c[x]+=y;}
int ask(int x,int y=0)
{for(;x;x-=x&-x)y+=c[x];return y;}
void solve(int Y)//Y=sum[r]
{
	int X=K^Y;//X=sum[l-1]
	for(int i=1;i<=n;i++){
		sX[i]=sX[i-1]+(sum[i-1]==X);
		sY[i]=sY[i-1]+(sum[i]==Y);
	}
	for(int i=1;i<=n;i++)
	if(sum[i-1]==X)
		cc[i]+=prez[i]*(sY[n]-sY[i-1]);
	LL temp=0;
	for(int i=1;i<=n;i++){
		if(sum[i-1]==X)temp+=prez[i];
		if(sum[i]==Y)cc[i+1]-=temp;
	}
}
int main()
{
	n=read();m=read();K=read();B=sqrt(n);
	mp[0].push_back(0);
	for(int i=1;i<=n;i++)
		mp[sum[i]=sum[i-1]^(a[i]=read())].push_back(i);
	for(int i=1;i<=m;i++){
		int x=read(),y=read(),z=read();
		syz[i]=(SYZ){x,y,z};
		prez[x]+=z;prez[y+1]-=z;
	}
	for(int i=1;i<=n;i++)
		prez[i]+=prez[i-1];
	sort(syz+1,syz+1+m,cmp);
	for(auto&x:mp)//x.first是鍵,x.second是值
		reverse(x.second.begin(),x.second.end());
	for(int i=1,j=1;i<=n;i++){//i是左端點
		while(j<=m&&syz[j].x==i)
			change(1,syz[j].z),change(syz[j].y+1,-syz[j].z),j++;
		if(mp[sum[i-1]^K].size()<B)
		for(int x:mp[sum[i-1]^K]){//x是右端點
			if(x<i)break;
			int temp=ask(x);
			cc[i]+=temp;
			cc[x+1]-=temp;
		}
	}
	for(auto&x:mp)
	if(x.second.size()>=B)
		solve(x.first);
	for(int i=1;i<=n;i++)
		printf("%lld%c",(cc[i]+=cc[i-1])%=mod," \n"[i==n]);
}

現在我們考慮怎麼樣拆分一個三元組,以及如何處理錯誤統計的 \(l,r\)

對於 \((x,y,z)\) ,在處理 \(cc_l\) 時,我們不想讓 \(y<r\) 的那些區間統計上 \(z\)。我麼需要新開一個陣列,統計上需要減去的這些 \(z\)。(此時樹狀陣列的 \(c\) 陣列已經沒用了我們不如再次利用 \(c\))我們令 \(c_x+=z*cnt(r)\),這些 \(r\) 要滿足 \(r>y\)。同時令 \(c_{y+1}-=z*cnt(r)\)。這個意思是:在掃描到 \(l\ge x\)\(l\) 時,統計的答案要減去 \(c_x\),因為多出來的 \(r\) 不應該統計上去。統計到 \(y\) 後面的 \(l\) 時,不用減去這些了,因為本來就沒有統計上(不明白為什麼沒統計上的話,可以看一下\(prez\)​)。

在處理 \(cc_{r+1}\) 時,我們應當減去 \(x\le l\le y<r\) 對應的 \(z\)。我們在掃描三元組 \((x,y,z)\) 時,令 \(c[y+1]+=z*cnt(l)\),這些 \(l\) 滿足 \(x\le l \le y\)。意思是,掃描到 \(r\ge y+1\)\(r\) 時,統計的答案要少減去 \(c[y+1]\)。因為,前面對應的那些三元組,貢獻要減少 \(c[y+1]\) ,因為 \(r\) 越界了,那些 \(l\) 不會和這個 \(r\) 產生貢獻。

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int mod=1<<30;
int n,m,K,B,a[150010],sum[150010];
int sX[150010],sY[150010];
LL c[150010],cc[150010],prez[150010];
unordered_map<int,vector<int>>mp;
struct SYZ
{int x,y,z;}syz[150010];
inline int read()
{
	int x=0,w=0;char ch=0;
	while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
	while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return w?-x:x;
}
bool cmp(SYZ n1,SYZ n2)
{
	if(n1.x^n2.x)return n1.x<n2.x;
	return n1.y<n2.y;
}
void change(int x,int y)
{for(;x<=n;x+=x&-x)c[x]+=y;}
int ask(int x,int y=0)
{for(;x;x-=x&-x)y+=c[x];return y;}
void solve(int Y)//Y=sum[r]
{
	int X=K^Y;//X=sum[l-1]
	for(int i=1;i<=n;i++){
		sX[i]=sX[i-1]+(sum[i-1]==X);
		sY[i]=sY[i-1]+(sum[i]==Y);
	}
	memset(c,0,sizeof c);
	for(int i=1;i<=m;i++){
		int x=syz[i].x,y=syz[i].y,z=syz[i].z;
		c[x]+=1ll*z*(sY[n]-sY[y]);
		c[y+1]-=1ll*z*(sY[n]-sY[y]);
	}
	for(int i=1;i<=n;i++){
		c[i]+=c[i-1];
		if(sum[i-1]==X)
			cc[i]+=prez[i]*(sY[n]-sY[i-1])-c[i];
	}
	memset(c,0,sizeof c);
	for(int i=1;i<=m;i++){
		int x=syz[i].x,y=syz[i].y,z=syz[i].z;
		c[y+1]+=1ll*z*(sX[y]-sX[x-1]);
	}
	LL temp=0;
	for(int i=1;i<=n;i++){
		c[i]+=c[i-1];
		if(sum[i-1]==X)temp+=prez[i];
		if(sum[i]==Y)
			cc[i+1]-=temp-c[i];
	}
}
int main()
{
	n=read();m=read();K=read();B=sqrt(n);
	mp[0].push_back(0);
	for(int i=1;i<=n;i++)
		mp[sum[i]=sum[i-1]^(a[i]=read())].push_back(i);
	for(int i=1;i<=m;i++){
		int x=read(),y=read(),z=read();
		syz[i]=(SYZ){x,y,z};
		prez[x]+=z;prez[y+1]-=z;
	}
	for(int i=1;i<=n;i++)
		prez[i]+=prez[i-1];
	sort(syz+1,syz+1+m,cmp);
	for(auto&x:mp)//x.first是鍵,x.second是值
		reverse(x.second.begin(),x.second.end());
	for(int i=1,j=1;i<=n;i++){//i是左端點
		while(j<=m&&syz[j].x==i)
			change(1,syz[j].z),change(syz[j].y+1,-syz[j].z),j++;
		if(mp[sum[i-1]^K].size()<B)
		for(int x:mp[sum[i-1]^K]){//x是右端點
			if(x<i)break;
			int temp=ask(x);
			cc[i]+=temp;
			cc[x+1]-=temp;
		}
	}
	for(auto&x:mp)
	if(x.second.size()>=B)
		solve(x.first);
	for(int i=1;i<=n;i++)
		printf("%lld%c",(cc[i]+=cc[i-1])%=mod," \n"[i==n]);
}

相關文章