樹上遍歷:CCF 難得一遇的好題!
參考了洛谷的第一篇題解,所以思路會有點相似。
部分分
當 \(k=1\) 時,顯然方案總數為 \(\prod_{i=1}^{n}(d_i-1)!\),因為進入一個子節點後可以以任意順序遍歷它的所有出邊。
觀察
當遍歷出來的樹的形態確定時,能形成這棵樹的邊組合在一起一定是一條鏈,且這條鏈是這棵樹上兩個不同的葉子節點之間的鏈。
這個手模幾組樣例就應該理解了,證明比較感性,觀察可得一個節點的所有邊一定在遍歷出的樹中是一條鏈。所以我們一定要從一個邊以一條鏈的路徑走到其他的所有邊,於是這條鏈上中間的邊就是無法作為起點的,因為去了某個邊之後就回不來了。
每個遍歷生成樹都一定都這樣的一條鏈,且一旦確定這條鏈,那麼生成的新樹就有 \(\prod_{i=1}^{n}(d_i-1-[i\in V])!\) 種,其中 \(V\) 表示這條鏈上的節點。理解就是進入一個子節點之後,在鏈上的點已經確定了它從哪裡來,並且確定了它最後走的邊是哪條,因此是 \(d_i-2\)。
這個公式可以等價轉化為 \(\prod_{i=1}^{n}(d_i-1)! \times \prod_{i=1}^{|V|}(d_{V_i}-1)^{-1}\),於是我們就可以開始樹形 dp 了。
因為 \(\prod_{i=1}^{n}(d_i-1)!\) 的係數是所有鏈都要乘的,所以我們把它提出來放到最後再乘。
樹形 dp
顯然,現在這個問題已經被轉化為了求解帶權的每條連線兩個葉子鏈的總和是多少。
我們設計 \(dp_{i,0/1}\) 表示以節點 \(i\) 為根的子樹中目前一共有多少種合法子樹方案,且當前子樹中有沒有關鍵邊。
顯然,我們可以每次遍歷 \(i\) 節點的一個子樹之後,在他們的 LCA,即節點 \(i\) 處統計答案,即可不重不漏。
先能寫出遍歷到當前子樹的答案:
- 當連線子樹的邊是關鍵邊時,此時一定滿足統計進答案的標準,那麼 \(res=res+(dp_{v,0}+dp_{v,1})\times(dp_{i,0}+dp_{i,1})\),可以透過把他們乘開來理解這個式子。
- 當連線子樹的邊是關鍵邊時,此時不一定滿足統計進答案的標準,那麼就要讓前面遍歷的子樹或者當前子樹中至少存在一條關鍵邊,則 \(res=res+dp_{i,0}\times dp_{v,1}+dp_{i,1}\times dp_{v,1}+dp_{i,1}\times dp_{v,0}\)。
接下來考慮合併子樹的 \(dp\) 值:
- 當連線的子樹是關鍵邊時,此時以 \(i\) 為根的子樹內一定含有關鍵邊,那麼只轉移一個式子即可:
- 當連線的子樹不是關鍵邊時,就都能轉移:
注意在最後要乘上 \((d_i-1)^{-1}\),包括 \(res\) 與 \(dp\) 值,才能保證求的是這個式子。
如果當前遍歷到了葉子節點,那麼注意在 dp 完後把 \(dp_{i,0}\) 賦值為 \(1\),才能統計到答案。這也是為什麼不能從一個葉子節點開始 dfs 的原因,會漏加這個 \(1\)。
時間複雜度 \(O(Tn)\)。
程式碼
鏈式前向星記得開雙倍空間!!!
#include <bits/stdc++.h>
#define fi first
#define se second
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pi;
const ll mod=1e9+7;
const int N=200005;
ll inv[N+100];
int n,m,eu[N],ev[N];
int h[N],e[N],ne[N],idx,id[N],d[N];
ll ans=0,dp[N][2];
bitset<N>sig;
void init()
{
sig.reset();
memset(h,-1,sizeof(h));
memset(d,0,sizeof(d));
idx=0;
ans=0;
}
void add(int u,int v,int x)
{
idx++;
ne[idx]=h[u];
h[u]=idx;
e[idx]=v;
id[idx]=x;
d[u]++;
}
void dfs(int u,int fa)
{
ll res=0;
dp[u][0]=dp[u][1]=0;
for(int i=h[u];i!=-1;i=ne[i])
{
int v=e[i],x=id[i];
if(v==fa)continue;
dfs(v,u);
if(sig[x])
{
res=(res+(dp[v][1]+dp[v][0])*(dp[u][1]+dp[u][0])%mod)%mod;
dp[u][1]=(dp[u][1]+dp[v][0]+dp[v][1])%mod;
}
else
{
res=(res+dp[u][1]*dp[v][0]%mod+dp[u][1]*dp[v][1]%mod+dp[u][0]*dp[v][1]%mod)%mod;
dp[u][1]=(dp[u][1]+dp[v][1])%mod;
dp[u][0]=(dp[u][0]+dp[v][0])%mod;
}
}
if(d[u]==1)dp[u][0]++;
ans=(ans+res*inv[d[u]-1]%mod)%mod;
dp[u][0]=(dp[u][0]*inv[d[u]-1])%mod;
dp[u][1]=(dp[u][1]*inv[d[u]-1])%mod;
}
void solve()
{
scanf("%d%d",&n,&m);
init();
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&eu[i],&ev[i]);
add(eu[i],ev[i],i);
add(ev[i],eu[i],i);
}
for(int i=1;i<=m;i++)
{
int x;
scanf("%d",&x);
sig[x]=1;
}
for(int i=1;i<=n;i++)
{
if(d[i]>1)
{
dfs(i,-1);
break;
}
}
for(int i=1;i<=n;i++)
{
for(int j=1;j<d[i];j++)
{
ans=(ans*j)%mod;
}
}
printf("%lld\n",ans);
}
int main()
{
freopen("traverse.in","r",stdin);
freopen("traverse.out","w",stdout);
inv[0]=inv[1]=1;
for(int i=2;i<=N;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
int c,t;
scanf("%d%d",&c,&t);
while(t--)solve();
return 0;
}