好神仙的題目。賽時胡了一個狀態和轉移都和官解不同的做法,得到了 \(O(n10^m)\) 的優秀複雜度。卡了一場常卡進了 \(75\) 分。這個做法和官解關係不大,並且很難進行最後的最佳化部分,所以在此不再贅述。
首先考慮 \(k=1\) 的情況。考慮記錄一些狀態能夠描述子樹內的選擇方案,\(0\) 表示整個子樹沒有被覆蓋過,\(1\) 表示子樹內部有點被覆蓋過並且子樹外的點還能被覆蓋,\(2\) 表示子樹內部有點被覆蓋過並且子樹外的點不能被覆蓋了。考慮轉移,需要把轉移描述為只和 \(u,v\) 有關的形式才能較為簡單的擴充套件到 \(k\neq 1\) 的情況。發現對於 \(1\rightarrow 2\) 的轉移,很難描述為 \(u,v\) 的形式,因為需要出現兩個子樹為 \(1\) 或者根節點被選擇才能轉移到 \(2\)。所以考慮記錄輔助狀態 \(3\) 表示出現過至少 \(2\) 次 \(1\) 的方案。那麼轉移有以下 \(8\) 種:
上面沒有出現過的轉移為不合法或者不存在對應狀態。這麼轉移之後再考慮和根節點是否選擇合併的轉移,那麼有:
轉移的同時計入 \(p,a\) 兩個陣列的貢獻。最後將 \(3\) 狀態放到 \(1,2\) 兩種狀態即可。因為 \(3\) 狀態對應的狀態可以封口也可以不封口。複雜度 \(O(n)\)。
考慮對於 \(k\neq 1\) 的情況,每一位暴力列舉上面的 \(8\) 種轉移,第一部分的轉移複雜度是 \(O(8^k)\) 的。對於複合根節點情況的部分,暴力列舉根節點狀態顯然不優,可以類似 FMT 的對每一位依次進行變換,也就是逐位列舉根節點狀態並處理這一位變換後的位置。複雜度為 \(O(k4^k)\)。對於 \(3\) 狀態的下放可以用類似的做法也做到 \(O(k4^k)\)。複雜度 \(O(n(8^k+k4^k))\),視常數可以獲得 \(45\sim 85\) 分。
考慮最佳化,目前的瓶頸在於 \(O(8^k)\) 的部分。一個很神秘的做法是考慮到如果沒有輔助狀態 \(3\),那麼轉移只有 \(O(5^k)\)。所以考慮列舉兒子的一些位置的狀態欽定為 \(3\),由於對於 \(3\) 的轉移是和 \(0/1\) 複合之後仍然為 \(3\),所以為 \(3\) 的位可以讓它的值為對應位為 \(0/1\) 的和。類似 OR 卷積的 FWT,經過一次正變換之後為 \(3\) 的位置真實值可以為 \(0\) 或 \(1\)。然後對變換之後的部分進行 \(O(5^k)\) 的轉移,但是多了 \(3\) 的狀態,由於經過了變換,只需要加入 \((3,3)\rightarrow 3\) 的轉移。這部分轉移的複雜度是 \(O(6^k)\) 的。對於轉移之後 \(3\) 的位置,他們是從 \((0/1,0/1)\) 轉移過來的,所以真實值可能是 \(0/1/3\),所以要進行一次類似 OR 卷積的 IFWT 讓他變成真實值為 \(3\) 的值。FWT 和 IFWT 的複雜度是 \(O(k4^k)\),所以總的複雜度就是 \(O(n(6^k+k4^k))\),可以透過。
#include<bits/stdc++.h>
using namespace std;
struct edge{int v,nxt;}e[205];
int n,m,u,v,cnt,h[105],w[105][256],p[105][8],dp[105][1<<16],num,tmp[1<<16];
void add(int u,int v){e[++cnt]={v,h[u]};h[u]=cnt;}
const int mod=998244353;
void Add(int &x,int y){x=(x+y>=mod?x+y-mod:x+y);}
struct node{int x,y,z;}go[2000005];
void init(int k,int x,int y,int z)
{
if(k==m){go[++num]={x,y,z};return;}
init(k+1,x,y,z);
init(k+1,x,y|(1<<(k<<1)),z|(1<<(k<<1)));
init(k+1,x|(1<<(k<<1)),y,z|(1<<(k<<1)));
init(k+1,x|(2<<(k<<1)),y,z|(2<<(k<<1)));
init(k+1,x,y|(2<<(k<<1)),z|(2<<(k<<1)));
init(k+1,x|(3<<(k<<1)),y|(3<<(k<<1)),z|(3<<(k<<1)));
}
void fwt(int *a)
{
for(int i=0;i<m;i++)
{
for(int s=0;s<(1<<(m<<1));s++)
{
int c=(s>>(i<<1))&3;
if(c==0)Add(a[s+(3<<(i<<1))],a[s]);
else if(c==1)Add(a[s+(2<<(i<<1))],a[s]);
}
}
}
void ifwt(int *a)
{
for(int i=0;i<m;i++)
{
for(int s=0;s<(1<<(m<<1));s++)
{
int c=(s>>(i<<1))&3;
if(c==3)Add(a[s],mod-a[s-(3<<(i<<1))]),Add(a[s],mod-a[s-(2<<(i<<1))]);
}
}
}
void dfs(int u,int fa)
{
dp[u][0]=1;
for(int i=h[u];i;i=e[i].nxt)
{
int v=e[i].v;
if(v==fa)continue;
dfs(v,u);
for(int s=0;s<(1<<(m<<1));s++)tmp[s]=dp[u][s],dp[u][s]=0;
fwt(tmp);fwt(dp[v]);
for(int s=1;s<=num;s++)Add(dp[u][go[s].z],1ll*tmp[go[s].x]*dp[v][go[s].y]%mod);
ifwt(dp[u]);
}
for(int i=0;i<m;i++)
{
for(int s=0;s<(1<<(m<<1));s++)tmp[s]=dp[u][s],dp[u][s]=0;
for(int s=0;s<(1<<(m<<1));s++)
{
int c=(s>>(i<<1))&3;
if(c==0)
{
Add(dp[u][s],1ll*tmp[s]*(mod+1-p[u][i])%mod);
Add(dp[u][s|(3<<(i<<1))],1ll*tmp[s]*p[u][i]%mod);
}
else if(c==1)
{
Add(dp[u][s],1ll*tmp[s]*(mod+1-p[u][i])%mod);
Add(dp[u][s|(2<<(i<<1))],1ll*tmp[s]*p[u][i]%mod);
}
else if(c==2)
{
Add(dp[u][s],1ll*tmp[s]*(mod+1-p[u][i])%mod);
}
else
{
Add(dp[u][s],tmp[s]);
}
}
}
for(int s=0;s<(1<<(m<<1));s++)
{
int ns=0;
for(int i=0;i<m;i++)if((s>>(i<<1))&1)ns|=(1<<i);
dp[u][s]=1ll*dp[u][s]*w[u][ns]%mod;
}
for(int i=0;i<m;i++)
{
for(int s=0;s<(1<<(m<<1));s++)tmp[s]=dp[u][s],dp[u][s]=0;
for(int s=0;s<(1<<(m<<1));s++)
{
if(((s>>(i<<1))&3)==3)
{
Add(dp[u][s-(1<<(i<<1))],tmp[s]);
Add(dp[u][s-(2<<(i<<1))],tmp[s]);
}
else Add(dp[u][s],tmp[s]);
}
}
}
int main()
{
//freopen("e.in","r",stdin);
cin.tie(0)->sync_with_stdio(0);
cin>>n>>m;
for(int i=1;i<n;i++)
{
cin>>u>>v;
add(u,v);add(v,u);
}
for(int i=0;i<m;i++)for(int j=1;j<=n;j++)cin>>p[j][i];
for(int i=1;i<=n;i++)
{
for(int s=0;s<(1<<m);s++)cin>>w[i][s];
}
init(0,0,0,0);dfs(1,0);
int ans=0;
for(int s=0;s<(1<<(m<<1));s++)
{
int flag=1;
for(int i=0;i<m;i++)flag&=(((s>>(i<<1))&3)!=1);
if(flag)Add(ans,dp[1][s]);
}
cout<<ans;
return 0;
}