題意簡述
給定一棵 \(n\) 個節點的樹。一次操作為斷邊再連邊,需保證仍為樹。求 \(k\) 次操作後樹的形態個數。
\(n \leq 5000\),\(0 \leq k \leq n\)。
被繞暈了,對我來說是一道超好的題,寫一篇題解理清思路,也希望幫助到你。
題目分析
由於我們總能浪費步數,即斷開連上同一條邊,或連上斷開同一條邊,所以我們求的答案為 \(0 \sim k\) 次不浪費的操作後,樹的形態個數。
考慮操作了 \(k\) 次,我們就斷開了 \(k\) 條邊,並且這些邊不會被連上。如果不考慮後來連上的 \(k\) 條邊,此時樹被劃分為了 \(k + 1\) 個聯通塊。答案就是用 \(k\) 條邊,不連上斷開的邊,把 \(k + 1\) 個聯通塊連成一棵樹的方案數。
對於不會浪費步數的證明:
假設斷開了一條連上的邊,那麼剩餘邊數不足以將其連成一棵樹,故每次連邊都不會浪費。
對於 \(k\) 個操作一起考慮的正確性證明:
即需要證明,存在一種方案,過程中始終合法,斷邊之後再連邊,仍為一棵樹。
由於不能斷開一條新連上的邊,所以所有新連上的邊會被儲存到最後,即每次連上的邊只能答案中選取。
考慮某一次斷邊,在這次斷邊之前的操作均合法。斷邊後會得到兩個聯通塊,需要在最終答案中的 \(k\) 條新增的邊中,使用一條未被使用的邊。最終答案裡肯定有聯通這兩個聯通塊的邊,否則不是一棵樹,矛盾。而在斷邊之前我們保證這是一棵樹,所以這些邊都是沒有被使用的。所以我們一定能夠選擇一條邊來連上。
我們只需要考慮的浪費就是連上一條斷開的邊。如果沒有這個限制,答案怎麼統計呢?
根據 Prufer 序列經典結論,用 \(k\) 條邊把大小分別為 \(a_1, \ldots, a_{k + 1}\) 的聯通塊連成樹的方案數為 \(n ^ {(k + 1) - 2} \prod \limits _ {i = 1} ^ {k + 1} a_i\)。前者對於每一種方案都是相同的,也很好求。重點在於求出每一種方案下,後者之和,即求 \(\sum \prod \limits _ {i = 1} ^ {k + 1} a_i\)。
考慮 \(\prod a_i\) 的意義,即為從每一個聯通塊中,選出一個關鍵點,求方案數。同時為了 \(\sum\),我們還要統計所有情況下的方案數之和。不妨記 \(F[u][x][0 / 1]\) 表示以 \(u\) 為根的子樹中,得到 \(x\) 個聯通塊,並且根節點所在聯通塊有沒有選出關鍵點,所有情況下的 \(\prod a_i\) 之和。邊界 \(F[i][1][0 / 1] = 1\)。
考慮樹形 DP 孩子 \(u\) 合併到 \(v\)。決策是否斷邊,分類討論一下。
-
斷開 \(u\) 和 \(v\) 之間的邊。
即此時 \(u\) 和 \(v\) 不在同一聯通塊中,那麼 \(u\) 中關鍵點必須選出。\[F[v][i + j][0 / 1] \gets F[v][i + j][0 / 1] + F[v][i][0 / 1] \times F[u][j][1] \]之所以兩個 \(\sum\prod\) 直接相乘,是因為 \(u\) 中每一個 \(\prod\) 需要分別和 \(v\) 中的 \(\prod\) 相乘產生一種方案,最後再求和,兩個 \(\sum \prod\) 直接相乘就滿足了。
-
不斷開 \(u\) 和 \(v\) 之間的邊。
類似地得到如下轉移方程。\[F[v][i + j][0] \gets F[v][i + j][0] + F[v][i][0] \times F[u][j][0] \]\[F[v][i + j][1] \gets F[v][i + j][1] + F[v][i][0] \times F[u][j][1] + F[v][i][1] \times F[u][j][0] \]
我們所求的 \(\sum \prod a_i\) 便是 \(F[1][k + 1][1]\)。也就是用 \(F[1][k + 1][0] \times n ^ {(k + 1) - 2}\) 表示最多改變了 \(k\) 條邊的樹的個數。我們為了不浪費步數,即不能連回去,要求的是恰好改變 \(k\) 條邊的樹的個數。用至多算恰好,很容易想到二項式反演,可以看我的《學習筆記》。
為了變成更熟悉的“至少”模型,我們反轉,最多改變 \(k\) 條邊,變成至少保證 \(n - 1 - k\) 條邊不變。設 \(f(n - 1 - x) = dp[1][x + 1][1] \times n ^ {(x + 1) - 2}\),即 \(f(x) = dp[1][n - x][1] \times n ^ {(n - x) - 2}\) 表示至少 \(x\) 條邊不變。
考慮怎麼表示出 \(g(x)\) 表示恰好 \(x\) 條邊不變。\(g(x)\) 在 \(f(x)\) 的基礎上,需要去掉大於 \(x\) 條邊不變的情況。對於一個 \(t > x\),\(g(t)\) 反覆貢獻了 \(\dbinom{t}{x}\) 次。
移項得:
反演得到:
對於操作了 \(k\) 次,即為保證 \(n - 1 - k\) 條邊不變,答案為 \(g(n - 1 - k)\)。總答案即為 \(\sum \limits _ {i = 0} ^ k g(n - 1 - i)\)。
時間複雜度:\(\Theta(n ^ 2)\),瓶頸在於樹形 DP 和反演。
樹形 DP 時間複雜度 \(\Theta(n ^ 2)\) 說明:
考慮一對 \((u, v)\) 只會在 \(\operatorname{lca}\) 處合併,共有 \(\Theta(n^ 2)\) 個點對。
注意 \(k\) 需要和 \(n - 1\) 取 \(\min\)。
程式碼
#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
const int N = 5010;
const int mod = 998244353;
inline int add(int a, int b) { return a + b >= mod ? a + b - mod : a + b; }
inline int sub(int a, int b) { return a - b < 0 ? a - b + mod : a - b; }
inline int mul(int a, int b) { return 1ll * a * b % mod; }
inline void toadd(int& a, int b) { a = add(a, b); }
int n, k, fa[N];
int frac[N], ifrac[N], Inv[N];
inline int C(int n, int m) { return mul(frac[n], mul(ifrac[m], ifrac[n - m])); }
int f[N][N][2]; // j 個聯通塊,根所在的聯通塊有沒有被選中
int siz[N];
int F[N], G[N];
signed main() {
#ifndef XuYueming
freopen("kaisou.in", "r", stdin);
freopen("kaisou.out", "w", stdout);
#endif
scanf("%d%d", &n, &k), k = min(k, n - 1);
for (int i = 2; i <= n; ++i) scanf("%d", &fa[i]), ++fa[i];
frac[0] = ifrac[0] = 1;
for (int i = 1; i <= n; ++i) {
frac[i] = mul(frac[i - 1], i);
Inv[i] = i == 1 ? 1 : mul(mod - mod / i, Inv[mod % i]);
ifrac[i] = mul(ifrac[i - 1], Inv[i]);
f[i][1][0] = f[i][1][1] = 1;
siz[i] = 1;
}
for (int u = n; u >= 1; --u) {
int v = fa[u];
static int g[N][2];
memcpy(g, f[v], sizeof(g));
memset(f[v], 0x00, sizeof(f[v]));
for (int i = 1; i <= siz[v]; ++i)
for (int j = 1; j <= siz[u]; ++j) {
toadd(f[v][i + j][0], mul(f[u][j][1], g[i][0]));
toadd(f[v][i + j][1], mul(f[u][j][1], g[i][1]));
toadd(f[v][i + j - 1][0], mul(f[u][j][0], g[i][0]));
toadd(f[v][i + j - 1][1], add(mul(f[u][j][1], g[i][0]), mul(f[u][j][0], g[i][1])));
}
siz[v] += siz[u];
}
F[n - 1] = 1;
for (int i = n - 2, p = 1; i >= 0; --i) F[i] = mul(p, f[1][n - i][1]), p = mul(p, n);
// p 即為 n^{(n-x)-2}
int ans = 0;
for (int i = n - 1; i >= n - 1 - k; --i) {
for (int j = i; j <= n - 1; ++j) {
if ((j - i) & 1)
G[i] = sub(G[i], mul(C(j, i), F[j]));
else
G[i] = add(G[i], mul(C(j, i), F[j]));
}
ans = add(ans, G[i]);
}
printf("%d", ans);
return 0;
}