思路
首先可以發現這個期望其實是假的,我們只需要把所有方案的答案加起來,最後除以 \((\frac{n(n-1)}{2})^2\) 即可,現在考慮如何統計所有方案的答案。
我們先考慮一條路徑的方案數:假設存在一條從 \(x\) 到 \(y\) 的公共路徑,其中 \(x\) 是 \(y\) 的祖先,那麼小紅和小藍分別選擇的路徑,其中一邊的端點肯定在 \(y\) 的子樹內,總共的方案數為 \(siz[y]^2\)。但是這並不正確,我們不妨記 \(y\) 的兒子為 \(y_1, y_2,...,y_m\),那麼如果選擇的端點都落在 \(y_1\) 內,那麼實際的公共路徑就變為了 \(x \rightarrow y_1\),所以實際的方案數是 \(siz[y]^2- \sum_{y_i \in son_y}siz[y_i]^2\),下面我們記 \(siz2[y]\) 為 \(\sum_{y_i \in son_y}siz[y_i]^2\)。
知道這個之後,我們可以考慮樹形dp,由於 \((a+1)^2=a^2+2a+1\),所以我們可以考慮維護所有路徑的 \(i\) 次項和,即方案數乘路徑長度的 \(i\) 次方。
記 \(f[x][0/1/2]\) 為 \(x\) 子樹內所有到 \(x\) 的路徑的 \(0/1/2\) 次項之和,現在考慮轉移:記 \(y\) 為 \(x\) 的一個兒子,現在我們要把 \(y\) 向 \(x\) 合併,首先因為會加入 \(x-y\) 這條邊,所以所有邊的長度都會加一,即:
同時 \(x-y\) 這條邊也要算進我們的方案數,即:
這樣處理完之後直接加到 \(f[x]\) 裡即可。
考慮完轉移,現在我們考慮如何計算答案,同樣是 \(y\) 向 \(x\) 轉移的過程,我f們可以從 \(x\) 和 \(y\) 裡面分別選出一條路徑拼成一條完整的路徑,記兩條路徑分別為 \(h_1,h_2\),同時對應的方案數為 \(g(h_1), g(h_2)\),有:
除此之外,\(y\) 中的路徑也可以單獨成為一條可行的路徑,這種情況下 \(x\) 側端點的方案數為
即兩個端點不能都在同一個 \(x\) 的兒子的子樹內,也不能都在 \(x\) 的父親之外,這裡可以自己畫畫圖理解一下。
到這裡就把這道題解決了,更多細節見程式碼
程式碼
#include <bits/stdc++.h>
using i64 = long long;
constexpr int P = 998244353;
i64 power(i64 a, i64 b)
{
i64 res = 1;
for( ; b; b >>= 1, a = a * a % P)
if(b & 1) res = res * a % P;
return res;
}
void solve()
{
int n; std::cin >> n;
std::vector<std::vector<int>> adj(n);
for(int i = 1; i < n; i++)
{
int u, v; std::cin >> u >> v;
u--, v--;
adj[u].emplace_back(v);
adj[v].emplace_back(u);
}
std::vector<i64> siz(n), siz2(n);
auto init = [&](auto init, int x, int fa) -> void
{
siz[x] = 1;
for(auto y : adj[x])
{
if(y == fa) continue;
init(init, y, x);
siz[x] += siz[y];
siz2[x] += siz[y] * siz[y];
}
};
init(init, 0, -1);
std::vector f(n, std::vector<i64>(3));
i64 ans = 0;
auto dfs = [&](auto dfs, int x, int fa) -> void
{
for(auto y : adj[x])
{
if(y == fa) continue;
dfs(dfs, y, x);
f[y][2] = (f[y][2] + 2 * f[y][1] + f[y][0]) % P;
f[y][1] = (f[y][1] + f[y][0]) % P;
for(int i = 0; i < 3; i++) f[y][i] = (f[y][i] + siz[y] * siz[y] - siz2[y]) % P;
i64 res = (f[y][0] * f[x][2] + f[y][2] * f[x][0] + 2 * f[y][1] * f[x][1]) % P;
i64 res2 = f[y][2] * ((n - siz[y]) * (n - siz[y]) - (siz2[x] - siz[y] * siz[y]) - (n - siz[x]) * (n - siz[x]));
ans = (ans + res + res2) % P;
for(int i = 0; i < 3; i++) f[x][i] = (f[x][i] + f[y][i]) % P;
}
};
dfs(dfs, 0, -1);
i64 inv = power(1LL * n * (n - 1) / 2 % P, P - 2);
ans = ans * inv % P * inv % P;
std::cout << ans << "\n";
}
int main()
{
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t; std::cin >> t;
while(t--) solve();
return 0;
}