快速沃爾什變換 (FWT)學習筆記

liuchanglc發表於2020-12-30

證明均來自xht37 的洛谷部落格

作用

\(OI\) 中,\(FWT\) 是用於解決對下標進行位運算卷積問題的方法。

\(c_{i}=\sum_{i=j \oplus k} a_{j} b_{k}\)

其中 \(\oplus\) 是二元位運算中的一種。

實現

\(or\) 運算

構造 \(fwt[a]_i = \sum_{j|i=i} a_j\)

\(\begin{aligned} fwt[a] \times fwt[b] &= \left(\sum_{j|i=i} a_j\right)\left(\sum_{k|i=i} b_k\right) \\\\ &= \sum_{j|i=i} \sum_{k|i=i} a_jb_k \\\\ &= \sum_{(j|k)|i = i} a_jb_k \\\\ &= fwt[c] \end{aligned}\)

\([a]\)\(fwt[a]\) 可以分治解決

我們從小到大依次列舉長度為 \(2^i\) 的區間

設最高位為第 \(i\)

此時我們已經求出了前 \(i-1\) 位的貢獻

並且區間的左半部分最高位上的數字為 \(0\)

區間的右半部分最高位上的數字為 \(1\)

對於左邊的這些數,右邊的數顯然不可能是左邊的數的子集

只能由自己 \(i-1\) 位的貢獻轉移過來

但是左邊的數會給相應位置的右邊的數做出貢獻

因此我們在進行變換的時候要把這個貢獻加上

同樣在進行逆變換的時候相應地減去即可

程式碼

void fwtor(rg int A[],rg int typ){
	for(rg int k=1,o=2;o<=mmax;k<<=1,o<<=1){
		for(rg int j=0;j<mmax;j+=o){
			for(rg int i=0;i<k;i++){
				A[i+j+k]+=typ*A[i+j];
				A[i+j+k]=getmod(A[i+j+k]);
			}
		}
	}
}

\(and\) 運算

\(or\) 運算基本一樣,只是這次變成了右區間對左區間有貢獻

程式碼

void fwtand(rg int A[],rg int typ){
	for(rg int k=1,o=2;o<=mmax;k<<=1,o<<=1){
		for(rg int j=0;j<mmax;j+=o){
			for(rg int i=0;i<k;i++){
				A[i+j]+=typ*A[i+j+k];
				A[i+j]=getmod(A[i+j]);
			}
		}
	}
}

\(xor\) 運算

這種運算比較複雜,因為不再是簡單的子集的關係了

但是我們仍然可以用以上兩種運算的思想

定義 \(x\otimes y=\text{popcount}(x \& y) \bmod 2\)

其中 \(\text{popcount}\) 表示「二進位制下 \(1\) 的個數」

滿足 \((i \otimes j) \operatorname{xor} (i \otimes k) = i \otimes (j \operatorname{xor} k)\)

構造 \(fwt[a]_i = \sum_{i\otimes j = 0} a_j - \sum_{i\otimes j = 1} a_j\)

則有

\(\begin{aligned} fwt[a] \times fwt[b] &= \left(\sum_{i\otimes j = 0} a_j - \sum_{i\otimes j = 1} a_j\right)\left(\sum_{i\otimes k = 0} b_k - \sum_{i\otimes k = 1} b_k\right) \\ &=\left(\sum_{i\otimes j=0}a_j\right)\left(\sum_{i\otimes k=0}b_k\right)-\left(\sum_{i\otimes j=0}a_j\right)\left(\sum_{i\otimes k=1}b_k\right)-\left(\sum_{i\otimes j=1}a_j\right)\left(\sum_{i\otimes k=0}b_k\right)+\left(\sum_{i\otimes j=1}a_j\right)\left(\sum_{i\otimes k=1}b_k\right) \\ &=\sum_{i\otimes(j \operatorname{xor} k)=0}a_jb_k-\sum_{i\otimes(j\operatorname{xor} k)=1}a_jb_k \\ &= fwt[c] \end{aligned} \)

當最高位是 \(0\) 時,因為 \(0\&1=0\)\(0\&0=0\)

二進位制下 \(1\) 的個數不變

所以左邊區間的價值應為只考慮前 \(i-1\) 位時左邊區間的價值加上只考慮前 \(i-1\) 位時右邊區間的價值

而對於右邊區間,當 \(1\&1=1\) 時,二進位制下一的個數會發生變化

所以應該是隻考慮前 \(i-1\) 位時左邊區間的價值減去只考慮前 \(i-1\) 位時右邊區間的價值

逆變換就是反這來,乘上 \(\frac{1}{2}\) 即可

程式碼

void fwtxor(rg int A[],rg int typ){
	for(rg int k=1,o=2;o<=mmax;k<<=1,o<<=1){
		for(rg int j=0;j<mmax;j+=o){
			for(rg int i=0;i<k;i++){
				rg int x=1LL*A[i+j]*typ%mod,y=1LL*A[i+j+k]*typ%mod;
				A[i+j]=getmod(x+y);
				A[i+j+k]=getmod(x-y);
			}
		}
	}
}

題目

P5366 [SNOI2017]遺失的答案

題目傳送門

分析

先特判掉 \(G\) 不能整除 \(L\) 的情況

然後把 \(L\)\(n\) 同時除以 \(G\)

這樣問題就轉化為了在 \(1\)\(n\) 中選擇一些數

使得他們的最大公因數為 \(1\),最小公倍數為 \(L\)

\(L\) 進行質因數分解,設 \(L=p_1^{a_1}p_2^{a_2}...p_n^{a_n}\)

如果要滿足條件

那麼對於任意一個質因數 \(p_i\) ,選擇的數中必須至少存在一個數,使得它分解質因數後 \(p_i\) 的指數等於 \(a_i\)

同理,對於任意一個質因數 \(p_i\) ,選擇的數中必須至少存在一個數,不含有 \(p_i\) 這個質因數

第一個條件可以看做是否滿足上界,第二個條件可以看作是否滿足下界

因為 \(L\) 小於等於 \(10^{8}\),所以最多含有 \(8\) 個不同的質因數

因此可以狀壓

\(11\) 表示同時滿足上界和下界,\(10\) 表示只滿足上界,\(01\) 表示只滿足下界,\(00\) 表示上界和下界都不滿足

顯然滿足條件的只能是 \(L\) 的因數,我們可以把 \(L\) 的所有因數都篩出來

然後求出這些因數所代表的狀態

因數不會太多,最多隻有 \(768\)

如果沒有必須選擇 \(x\) 的限制,那麼直接設 \(f[i][j]\) 表示考慮前 \(i\) 個數,狀態為 \(j\) 的方案數進行 \(dp\) 即可

如果考慮 \(x\) 的限制,我們就需要維護一個字首 \(dp\) 陣列 \(pre\) 和字尾 \(dp\) 陣列 \(suf\)

對於第 \(i\) 個數,我們把 \(pre[i-1]\)\(suf[i+1]\) 進行或運算卷積

最後只要第 \(i\) 個數的狀態與某個狀態進行或運算等於全集

那麼就可以累加這個狀態的答案

程式碼

#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define rg register
inline int read(){
	rg int x=0,fh=1;
	rg char ch=getchar();
	while(ch<'0' || ch>'9'){
		if(ch=='-') fh=-1;
		ch=getchar();
	}
	while(ch>='0' && ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*fh;
}
const int mod=1e9+7,maxn=70005,maxm=1005;
int n,g,l,q,x,mmax;
int getmod(rg int now){
	return now>=mod?now-mod:now<0?now+mod:now;
}
void fwtor(rg int A[],rg int typ){
	for(rg int o=2,k=1;o<=mmax;o<<=1,k<<=1){
		for(rg int i=0;i<mmax;i+=o){
			for(rg int j=0;j<k;j++){
				A[i+j+k]+=A[i+j]*typ;
				A[i+j+k]=getmod(A[i+j+k]);
			}
		}
	}
}
int pri[maxn],mi[maxn];
void divid(rg int now){
	rg int m=sqrt(now),ncnt=0;
	for(rg int i=2;i<=m;i++){
		if(now%i==0){
			ncnt=0;
			pri[++pri[0]]=i;
			while(now%i==0){
				now/=i;
				ncnt++;
			}
			mi[pri[0]]=ncnt;
		}
	}
	if(now>1){
		pri[++pri[0]]=now;
		mi[pri[0]]=1;
	}
}
int sta[maxn],tp,zt[maxn];
void getit(){
	rg int m=sqrt(l);
	for(rg int i=1;i<=m;i++){
		if(l%i==0){
			if(i<=n) sta[++tp]=i;
			if(i*i!=l && l/i<=n) sta[++tp]=l/i;
		}
	}
}
int pre[maxm][maxn],suf[maxm][maxn],tmp[maxn],ans[maxn];
int getzt(rg int now){
	rg int zt0=0,zt1=0;
	for(rg int i=1;i<=pri[0];i++){
		rg int ncnt=0;
		while(now%pri[i]==0){
			now/=pri[i];
			ncnt++;
		}
		if(ncnt==0) zt0|=(1<<(i-1));
		else if(ncnt==mi[i]) zt1|=(1<<(i-1));
	}
	return zt0|(zt1<<pri[0]);
}
int main(){
	n=read(),g=read(),l=read(),q=read();
	if(l%g){
		for(rg int i=1;i<=q;i++){
			x=read();
			printf("0\n");
		}
	} else {
		l/=g,n/=g;
		divid(l);
		mmax=1<<(2*pri[0]);
		getit();
		std::sort(sta+1,sta+1+tp);
		for(rg int i=1;i<=tp;i++) zt[i]=getzt(sta[i]);
		pre[0][0]=suf[tp+1][0]=1;
		for(rg int i=1;i<=tp;i++){
			memcpy(pre[i],pre[i-1],sizeof(pre[i-1]));
			for(rg int j=0;j<mmax;j++){
				pre[i][j|zt[i]]=getmod(pre[i][j|zt[i]]+pre[i-1][j]);
			}
		}
		for(rg int i=tp;i>=1;i--){
			memcpy(suf[i],suf[i+1],sizeof(suf[i+1]));
			for(rg int j=0;j<mmax;j++){
				suf[i][j|zt[i]]=getmod(suf[i][j|zt[i]]+suf[i+1][j]);
			}
		}
		for(rg int i=0;i<=tp+1;i++){
			fwtor(pre[i],1);
			fwtor(suf[i],1);
		}
		for(rg int i=1;i<=tp;i++){
			for(rg int j=0;j<mmax;j++){
				tmp[j]=1LL*pre[i-1][j]*suf[i+1][j]%mod;
			}
			fwtor(tmp,-1);
			for(rg int j=0;j<mmax;j++){
				if((zt[i]|j)==mmax-1) ans[i]=getmod(ans[i]+tmp[j]);
			}
		}
		for(rg int i=1;i<=q;i++){
			x=read();
			if(x%g) printf("0\n");
			else {
				x/=g;
				if(l%x) printf("0\n");
				else {
					rg int wz=std::lower_bound(sta+1,sta+1+tp,x)-sta;
					printf("%d\n",ans[wz]);
				}
			}
		}
	}
	return 0;
}

P3175 [HAOI2015]按位或

題目傳送門

分析

要用到 \(min\)\(max\) 容斥

不會的可以看我的容斥原理學習筆記

$max(S)=\sum_{T\subseteq S}(-1)^{|T|+1}min(T) $

$min(S)=\sum_{T\subseteq S}(-1)^{|T|+1}max(T) $

\(max(S)\)\(S\) 中最晚的元素出現的期望次數

\(min(S)\)\(S\) 中最早的元素出現的期望次數

問題轉換為如何求 \(min(T)\)

\(P=\sum\limits_{S\subseteq\complement_UT}P(S)\)

\(E(\min(T))=P\sum\limits^{+\infty}_{i=1}i(1-p)^{i-1}\)

有邊是一個等比數列乘等差數列的求和公式

化簡之後是

\(\frac{1-(1-P)^n}{P^2}-\frac{n(1-P)^n}{P}\)

\(n\) 趨進於無窮大時

\((1-P)^n\) 趨進於 \(0\)

因此最終的結果是 \(\frac{1}{P^2}\)

再乘上外面的一個 \(P\),就是 \(\frac{1}{P}\)

剩下的再用一個或運算卷積即可

程式碼

#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#define rg register
inline int read(){
	rg int x=0,fh=1;
	rg char ch=getchar();
	while(ch<'0' || ch>'9'){
		if(ch=='-') fh=-1;
		ch=getchar();
	}
	while(ch>='0' && ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*fh;
}
const int maxn=2e6+5,mod=998244353;
const double eps=1e-10;
int n,mmax,siz[maxn];
double a[maxn];
void fwtor(rg double A[],rg int typ){
	for(rg int i=1;i<=n;i++){
		for(rg int j=0;j<mmax;j+=(1<<i)){
			for(rg int k=0;k<1<<(i-1);k++){
				A[j|(1<<(i-1))|k]+=A[j|k]*typ;
			}
		}
	}
}
int main(){
	n=read();
	mmax=1<<n;
	for(rg int i=0;i<mmax;i++){
		scanf("%lf",&a[i]);
	}
	fwtor(a,1);
	for(rg int i=0;i<mmax;i++){
		siz[i]=siz[i>>1]+(i&1);
	}
	rg double nans=0;
	for(rg int i=1;i<mmax;i++){
		if(1.0-a[(mmax-1)^i]<eps){
			printf("INF\n");
			return 0;
		}
		nans+=((siz[i]&1)?(1.0):(-1.0))/(1.0-a[(mmax-1)^i]);
	}
	printf("%.8f\n",nans);
	return 0;
}

P5643 [PKUWC2018]隨機遊走

題目傳送門

分析

同樣是 \(min\)-\(max\) 容斥,先求出至少經過一個點的期望步數

然後再求出全部經過的期望步數

$max(S)=\sum_{T\subseteq S}(-1)^{|T|+1}min(T) $

\(f_i\) 表示從 \(i\) 出發,經過 \(S\) 中的至少一個點的期望步數

\(deg_i\) 為點 \(i\) 的度數,\(j\)\(i\) 的兒子節點

可以得到這樣的遞推式:

\(f_i=\frac1{deg_i}(f_{fa_i}+\sum f_j)+1\)

\(f_i=k_if_{fa_i}+b_i\)

化簡之後可以得到

$f_i=\frac1{deg_i-\sum k_j}f_{fa_i}+\frac{deg_i+\sum b_j}{deg_i-\sum k_j} $

$k_i=\frac 1{deg_i-\sum k_j},b_i=\frac{deg_i+\sum b_j}{deg_i-\sum k_j} $

這個東西可以 \(dfs\) 求出來

然後就可以用或運算卷積合併預處理出每一個集合的答案

程式碼

#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstring>
#define rg register
inline int read(){
	rg int x=0,fh=1;
	rg char ch=getchar();
	while(ch<'0' || ch>'9'){
		if(ch=='-') fh=-1;
		ch=getchar();
	}
	while(ch>='0' && ch<='9'){
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*fh;
}
const int maxn=1e6+5,maxm=25,mod=998244353;
int n,q,x,mmax,h[maxm],k[maxm],a[maxm],tot=1,du[maxm],ans[maxn],siz[maxn];
int ksm(rg int ds,rg int zs){
	rg int nans=1;
	while(zs){
		if(zs&1) nans=1LL*nans*ds%mod;
		ds=1LL*ds*ds%mod;
		zs>>=1;
	}
	return nans;
}
struct asd{
	int to,nxt;
}b[maxm<<1];
void ad(rg int aa,rg int bb){
	b[tot].to=bb;
	b[tot].nxt=h[aa];
	h[aa]=tot++;
}
int getmod(rg int now){
	return (now>=mod)?(now-mod):((now<0)?(now+mod):now);
}
void fwtor(rg int A[],rg int typ){
	for(rg int o=2,k=1;o<=mmax;o<<=1,k<<=1){
		for(rg int i=0;i<mmax;i+=o){
			for(rg int j=0;j<k;j++){
				A[i+j+k]+=A[i+j]*typ;
				A[i+j+k]=getmod(A[i+j+k]);
			}
		}
	}
}
void dfs(rg int now,rg int lat,rg int zt){
	if(zt&(1<<(now-1))) return;
	rg int ans1=0,ans2=0;
	for(rg int i=h[now];i!=-1;i=b[i].nxt){
		rg int u=b[i].to;
		if(u==lat) continue;
		dfs(u,now,zt);
		ans1+=k[u];
		ans2+=a[u];
		ans1=getmod(ans1);
		ans2=getmod(ans2);
	}
	k[now]=ksm(getmod(du[now]-ans1),mod-2);
	a[now]=1LL*k[now]*getmod(du[now]+ans2)%mod;
}
int main(){
	memset(h,-1,sizeof(h));
	n=read(),q=read(),x=read();
	rg int aa,bb,cc;
	for(rg int i=1;i<n;i++){
		aa=read(),bb=read();
		ad(aa,bb);
		ad(bb,aa);
		du[aa]++,du[bb]++;
	}
	mmax=1<<n;
	for(rg int i=0;i<mmax;i++) siz[i]=siz[i>>1]+(i&1);
	for(rg int i=0;i<mmax;i++){
		memset(k,0,sizeof(k));
		memset(a,0,sizeof(a));
		dfs(x,0,i);
		ans[i]=a[x]*((siz[i]&1)?1:(-1));
		ans[i]=getmod(ans[i]);
	}
	fwtor(ans,1);
	for(rg int i=1;i<=q;i++){
		aa=read();
		cc=0;
		for(rg int j=1;j<=aa;j++){
			bb=read();
			cc|=(1<<(bb-1));
		}
		printf("%d\n",ans[cc]);
	}
	return 0;
}

相關文章