7.7 T3
statement
有一棵樹,有點權 \(v_x\) ,每次可以縮起來一條邊,新點的點權為 \(v_xv_y+1\),求所有操作順序最後得到點的點權和。
\(n\leq 2000\)
solution
考慮組合意義,點權相當於可以選擇相乘或者選擇變成 \(1\),考慮我們用若干連通塊表示,如果選擇相乘就不合並連通塊,否則就合併連通塊,那麼最終大小 \(=1\) 的連通塊貢獻就是 \(v_x\),\(>1\) 的連通塊貢獻就是 \(1\) ,乘起來就是答案。
那麼直接對連通塊劃分計數,我們需要從每個連通塊內選擇出一條邊作為最後一條邊,然後要求其他邊的合併順序在這條邊之前,對於連線不同連通塊的邊,要求連線的兩個連通塊的最後一條邊的順序在這條邊之前,那麼我們把這些順序要求連邊,就等價於計數拓撲序個數。
注意到連出的邊忽略掉方向後變成了一棵樹,這個經典做法是:把向上的邊容斥成向下的,然後變成若干森林再進行外向樹的拓撲序計數,在本題中需要記錄當前樹的子樹大小以及當前連通塊的大小,複雜度 \(O(n^3)\),鏈上的時候可以把每種長度的連通塊預處理一下做到 \(O(n^2)\)。
然後好像做不到 \(O(n^2)\),因此我們要換個做法計算拓撲序個數。
我們把每個連通塊的最後一條邊稱作特殊邊。
考慮直接設 \(f_{i,j}\) 表示只考慮子樹內的邊,當前連通塊已經有了特殊邊,且特殊邊在這些邊裡排名為 \(j\) 的方案數,\(g_{i,j}\) 表示只考慮子樹內的邊,當前連通塊還沒有特殊邊,且假設特殊邊在這些邊裡排名為 \(j\) 的方案數(相當於提前把特殊邊的位置留出來)。
這樣設狀態的好處是 \(f\) 和 \(g\) 合併的時候只需要把特殊邊前後分別合併就好,複雜度就是對了,而合併不同連通塊的時候可以預處理一下字首和,總的來說可以透過一些分類討論做到 \(O(n^2)\)。
#include<bits/stdc++.h>
using namespace std;
const int N = 5020;
template <typename T>inline void read(T &x)
{
x=0;char c=getchar();bool f=0;
for(;c<'0'||c>'9';c=getchar())f|=(c=='-');
for(;c>='0'&&c<='9';c=getchar())x=(x<<1)+(x<<3)+(c-'0');
x=(f?-x:x);
}
typedef long long LL;
int n,W[N];
const int mod = 998244353;
inline int sub(int a,int b){return a-b<0?a-b+mod:a-b;}
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
inline int mul(int a,int b){return 1ll*a*b%mod;}
vector<int> T[N];
int C[N][N];
void init(int n)
{
for(int i=0;i<=n;i++)
{
C[i][0]=1;
for(int j=1;j<=i;j++)
C[i][j]=add(C[i-1][j-1],C[i-1][j]);
}
}
int f[N][N],g[N][N],h[N],siz[N],F[N],G[N];
inline int mk(int i,int j)
{
if(i<0||j<0)return 0;
return C[i+j][i];
}
int B[N];
inline int val(int a,int b,int c,int d){return mul(mk(c,d),mk(a-c,b-d));}
void dfs(int x,int pre)
{
siz[x]=1;
for(int y:T[x])if(y^pre)
{
dfs(y,x);
siz[x]+=siz[y];
}
h[x]=W[x];
int c=0;
g[x][0]=1;
for(int y:T[x])if(y^pre)
{
int H=h[x];h[x]=0;
for(int i=0;i<=siz[y];i++)B[i]=0;
int s=0;
for(int j=siz[y]-1;j>=0;j--)s=add(s,f[y][j]),B[j]=s;
h[x]=add(h[x],mul(mul(H,mul(h[y],siz[y])),mk(c,siz[y])));
for(int i=0;i<siz[y];i++)if(f[y][i])
h[x]=add(h[x],mul(mul(H,mul(f[y][i],i+1)),mk(c,siz[y])));
for(int i=0;i<=c;i++)
{
G[i]=g[x][i];F[i]=f[x][i];
g[x][i]=f[x][i]=0;
}
for(int i=0;i<=c;i++)
{
if(G[i])
{
for(int j=0;j<=siz[y]-1;j++)//G*H->G
g[x][i+j+1]=add(g[x][i+j+1],mul(mul(G[i],mul(i+1,h[y])),val(c+1,siz[y]-1,i+1,j)));
s=0;
for(int j=1;j<=siz[y];j++)//G*F->G
{
s=add(s,B[j-1]);
g[x][i+j]=add(g[x][i+j],mul(mul(G[i],s),val(c,siz[y],i,j)));
}
for(int j=0;j<=siz[y]-1;j++)//G*F->F
{
f[x][i+j]=add(f[x][i+j],mul(mul(G[i],f[y][j]),mul(val(c,siz[y]-1,i,j),siz[y]-1-j-1+1)));
g[x][i+j]=add(g[x][i+j],mul(mul(G[i],g[y][j]),mul(val(c,siz[y],i,j),siz[y]-1-j+1)));
f[x][i+j]=add(f[x][i+j],mul(mul(G[i],g[y][j]),val(c,siz[y]-1,i,j)));
}
}
if(F[i])
{
for(int j=0;j<=siz[y]-1;j++)//F*H->F
f[x][i+j+1]=add(f[x][i+j+1],mul(mul(F[i],mul(i+1,h[y])),val(c-1+1,siz[y]-1,i+1,j)));
s=0;
for(int j=1;j<=siz[y];j++)//F*F->F
{
s=add(s,B[j-1]);
f[x][i+j]=add(f[x][i+j],mul(mul(F[i],s),val(c-1,siz[y],i,j)));
}
for(int j=0;j<=siz[y]-1;j++)//F*G->F
{
f[x][i+j]=add(f[x][i+j],mul(mul(F[i],g[y][j]),mul(val(c-1,siz[y],i,j),siz[y]-1-j+1)));
}
}
}
c+=siz[y];
}
}
int main()
{
read(n);
for(int i=1;i<=n;i++)read(W[i]);
for(int i=1;i<n;i++)
{
int x,y;
read(x);read(y);
T[x].push_back(y);
T[y].push_back(x);
}
init(n);
dfs(1,0);
int ans=h[1];
for(int i=0;i<=n;i++)ans=add(ans,f[1][i]);
cout<<ans;
return 0;
}