P10717 題解

Harry27182發表於2024-07-25

好神仙的題目。賽時胡了一個狀態和轉移都和官解不同的做法,得到了 \(O(n10^m)\) 的優秀複雜度。卡了一場常卡進了 \(75\) 分。這個做法和官解關係不大,並且很難進行最後的最佳化部分,所以在此不再贅述。

首先考慮 \(k=1\) 的情況。考慮記錄一些狀態能夠描述子樹內的選擇方案,\(0\) 表示整個子樹沒有被覆蓋過,\(1\) 表示子樹內部有點被覆蓋過並且子樹外的點還能被覆蓋,\(2\) 表示子樹內部有點被覆蓋過並且子樹外的點不能被覆蓋了。考慮轉移,需要把轉移描述為只和 \(u,v\) 有關的形式才能較為簡單的擴充套件到 \(k\neq 1\) 的情況。發現對於 \(1\rightarrow 2\) 的轉移,很難描述為 \(u,v\) 的形式,因為需要出現兩個子樹為 \(1\) 或者根節點被選擇才能轉移到 \(2\)。所以考慮記錄輔助狀態 \(3\) 表示出現過至少 \(2\)\(1\) 的方案。那麼轉移有以下 \(8\) 種:

\[(0,0)\rightarrow 0 \]

\[(0,1)\rightarrow 1 \ \ (1,0)\rightarrow 1 \]

\[(0,2)\rightarrow 2 \ \ (2,0)\rightarrow 2 \]

\[(3,0)\rightarrow 3 \ \ (3,1)\rightarrow 3 \]

\[(1,1)\rightarrow 3 \]

上面沒有出現過的轉移為不合法或者不存在對應狀態。這麼轉移之後再考慮和根節點是否選擇合併的轉移,那麼有:

\[(0,0)\rightarrow 0 \ \ (1,0)\rightarrow 1 \ \ (2,0)\rightarrow 2 \ \ (3,0)\rightarrow 3 \ \ \]

\[(0,1)\rightarrow 3 \ \ (1,1)\rightarrow3 \ \ (3,1)\rightarrow 3 \ \ \]

轉移的同時計入 \(p,a\) 兩個陣列的貢獻。最後將 \(3\) 狀態放到 \(1,2\) 兩種狀態即可。因為 \(3\) 狀態對應的狀態可以封口也可以不封口。複雜度 \(O(n)\)

考慮對於 \(k\neq 1\) 的情況,每一位暴力列舉上面的 \(8\) 種轉移,第一部分的轉移複雜度是 \(O(8^k)\) 的。對於複合根節點情況的部分,暴力列舉根節點狀態顯然不優,可以類似 FMT 的對每一位依次進行變換,也就是逐位列舉根節點狀態並處理這一位變換後的位置。複雜度為 \(O(k4^k)\)。對於 \(3\) 狀態的下放可以用類似的做法也做到 \(O(k4^k)\)。複雜度 \(O(n(8^k+k4^k))\),視常數可以獲得 \(45\sim 85\) 分。

考慮最佳化,目前的瓶頸在於 \(O(8^k)\) 的部分。一個很神秘的做法是考慮到如果沒有輔助狀態 \(3\),那麼轉移只有 \(O(5^k)\)。所以考慮列舉兒子的一些位置的狀態欽定為 \(3\),由於對於 \(3\) 的轉移是和 \(0/1\) 複合之後仍然為 \(3\),所以為 \(3\) 的位可以讓它的值為對應位為 \(0/1\) 的和。類似 OR 卷積的 FWT,經過一次正變換之後為 \(3\) 的位置真實值可以為 \(0\)\(1\)。然後對變換之後的部分進行 \(O(5^k)\) 的轉移,但是多了 \(3\) 的狀態,由於經過了變換,只需要加入 \((3,3)\rightarrow 3\) 的轉移。這部分轉移的複雜度是 \(O(6^k)\) 的。對於轉移之後 \(3\) 的位置,他們是從 \((0/1,0/1)\) 轉移過來的,所以真實值可能是 \(0/1/3\),所以要進行一次類似 OR 卷積的 IFWT 讓他變成真實值為 \(3\) 的值。FWT 和 IFWT 的複雜度是 \(O(k4^k)\),所以總的複雜度就是 \(O(n(6^k+k4^k))\),可以透過。

#include<bits/stdc++.h>
using namespace std;
struct edge{int v,nxt;}e[205];
int n,m,u,v,cnt,h[105],w[105][256],p[105][8],dp[105][1<<16],num,tmp[1<<16];
void add(int u,int v){e[++cnt]={v,h[u]};h[u]=cnt;}
const int mod=998244353;
void Add(int &x,int y){x=(x+y>=mod?x+y-mod:x+y);}
struct node{int x,y,z;}go[2000005];
void init(int k,int x,int y,int z)
{
	if(k==m){go[++num]={x,y,z};return;}
	init(k+1,x,y,z);
	init(k+1,x,y|(1<<(k<<1)),z|(1<<(k<<1)));
	init(k+1,x|(1<<(k<<1)),y,z|(1<<(k<<1)));
	init(k+1,x|(2<<(k<<1)),y,z|(2<<(k<<1)));
	init(k+1,x,y|(2<<(k<<1)),z|(2<<(k<<1)));
	init(k+1,x|(3<<(k<<1)),y|(3<<(k<<1)),z|(3<<(k<<1)));
}
void fwt(int *a)
{
	for(int i=0;i<m;i++)
	{
		for(int s=0;s<(1<<(m<<1));s++)
		{
			int c=(s>>(i<<1))&3;
			if(c==0)Add(a[s+(3<<(i<<1))],a[s]);
			else if(c==1)Add(a[s+(2<<(i<<1))],a[s]); 
		}
	}
}
void ifwt(int *a)
{
	for(int i=0;i<m;i++)
	{
		for(int s=0;s<(1<<(m<<1));s++)
		{
			int c=(s>>(i<<1))&3;
			if(c==3)Add(a[s],mod-a[s-(3<<(i<<1))]),Add(a[s],mod-a[s-(2<<(i<<1))]);
		}
	}
}
void dfs(int u,int fa)
{
	dp[u][0]=1;
	for(int i=h[u];i;i=e[i].nxt)
	{
		int v=e[i].v;
		if(v==fa)continue;
		dfs(v,u);
		for(int s=0;s<(1<<(m<<1));s++)tmp[s]=dp[u][s],dp[u][s]=0;
		fwt(tmp);fwt(dp[v]);
		for(int s=1;s<=num;s++)Add(dp[u][go[s].z],1ll*tmp[go[s].x]*dp[v][go[s].y]%mod);
		ifwt(dp[u]);
	}
	for(int i=0;i<m;i++)
	{
		for(int s=0;s<(1<<(m<<1));s++)tmp[s]=dp[u][s],dp[u][s]=0;
		for(int s=0;s<(1<<(m<<1));s++)
		{
			int c=(s>>(i<<1))&3;
			if(c==0)
			{
				Add(dp[u][s],1ll*tmp[s]*(mod+1-p[u][i])%mod);
				Add(dp[u][s|(3<<(i<<1))],1ll*tmp[s]*p[u][i]%mod);
			}
			else if(c==1)
			{
				Add(dp[u][s],1ll*tmp[s]*(mod+1-p[u][i])%mod);
				Add(dp[u][s|(2<<(i<<1))],1ll*tmp[s]*p[u][i]%mod);
			}
			else if(c==2)
			{
				Add(dp[u][s],1ll*tmp[s]*(mod+1-p[u][i])%mod);
			}
			else 
			{
				Add(dp[u][s],tmp[s]);
			}
		}
	}
	for(int s=0;s<(1<<(m<<1));s++)
	{
		int ns=0;
		for(int i=0;i<m;i++)if((s>>(i<<1))&1)ns|=(1<<i);
		dp[u][s]=1ll*dp[u][s]*w[u][ns]%mod;
	}
	for(int i=0;i<m;i++)
	{
		for(int s=0;s<(1<<(m<<1));s++)tmp[s]=dp[u][s],dp[u][s]=0;
		for(int s=0;s<(1<<(m<<1));s++)
		{
			if(((s>>(i<<1))&3)==3)
			{
				Add(dp[u][s-(1<<(i<<1))],tmp[s]);
				Add(dp[u][s-(2<<(i<<1))],tmp[s]); 
			}
			else Add(dp[u][s],tmp[s]);
		}
	}
}
int main()
{
	//freopen("e.in","r",stdin);
	cin.tie(0)->sync_with_stdio(0);
	cin>>n>>m;
	for(int i=1;i<n;i++)
	{
		cin>>u>>v;
		add(u,v);add(v,u);
	}
	for(int i=0;i<m;i++)for(int j=1;j<=n;j++)cin>>p[j][i];
	for(int i=1;i<=n;i++)
	{
		for(int s=0;s<(1<<m);s++)cin>>w[i][s];
	}
	init(0,0,0,0);dfs(1,0);
	int ans=0;
	for(int s=0;s<(1<<(m<<1));s++)
	{
		int flag=1;
		for(int i=0;i<m;i++)flag&=(((s>>(i<<1))&3)!=1);
		if(flag)Add(ans,dp[1][s]);
	}
	cout<<ans;
	return 0;
}