前言
好題,這個待定係數真的好用!
思路
說實話,看見這道題的第一反應真不是什麼 \(\texttt{min-max}\) 容斥,而是狀壓。
但是既然都放在連結裡面了,那我們就勉為其難的向這方面思考一下吧。
首先直接透過這個容斥轉化題意,也就是說,我們只需要求得,對於任意一個點的集合 \(S\),使得從 \(x\) 出發,想要第一次到達 \(S\) 中的任意一點,所需要的期望次數。(簡單來說就是透過該容斥將題意轉化為 \(E(\min(S))\) 。)
具體來說,從 \(0\) 到 \(2^n\) 列舉所有子集 \(S\),對於每個子集,考慮樹形 \(\text{DP}\)。
定義 \(dp_i\) 表示從 \(i\) 開始,到達 \(S\) 中的任意一點所需要的最少期望次數。
如果 \(i\in S\),顯然 \(dp_i=0\)。否則,顯然有 \(dp_i=\dfrac{dp_{fa_i}+\sum_{j\in son_i}dp_j}{deg_i} + 1\)。(其中,\(deg_i\) 表示 \(i\) 的度數)
但是顯然,這樣的轉移是有後效性的。
一種比較暴力的想法是,把所有的式子列出來,然後高斯消元,複雜度為 \(n^3\times 2^n\),而且是在模意義下進行,實現代價巨大,不具有可行性。
我們考慮一個新式的 \(\text{trick}\): 待定係數法。
我們發現,這個式子是齊次的(且是一次的),而且必然可以透過高斯消元得到一個 \(dp_i\) 無後效性的表示。故我們考慮待定係數,假設 \(dp_i=k_i\times dp_{fa_i}+b_i\)。
我們可以透過這個假設,將轉移式子中的 \(dp_j\) 換掉,然後進行一波代數變形:
於是,我們有:\(k_i=\dfrac{1}{(deg_i-\sum k_j)},b_i=\dfrac{\sum b_j+deg_i}{(deg_i-\sum k_j)}\)
可以發現,\(k_i,b_i\) 的求法不具有後效性,且如果我們從 \(x\) 直接向下遞迴,則 \(dp_x=b_x\),因為沒有父親。
求出每個 \(S\) 之後,我們顯然不能每輸入一個就容斥一遍,也不能直接列舉子集,所以高維字首和,根據容斥,\(S\) 大小為奇數時就是加,否則減。
程式碼
#include <bits/stdc++.h>
using namespace std;
#define maxn 20
#define mod 998244353
int ksm(int x, int y)
{
int sum = 1;
while(y)
{
if(y & 1) sum = 1ll * sum * x % mod;
y >>= 1, x = 1ll * x * x % mod;
}
return sum;
}
int inv(int x)
{
return ksm(x, mod - 2);
}
int n, q, x;
int fst[maxn], cnt;
struct node
{
int tar, nxt;
}arr[maxn << 1];
void adds(int x, int y)
{
arr[++cnt].tar = y, arr[cnt].nxt = fst[x], fst[x] = cnt;
}
bool belong[maxn];
int k[maxn], b[maxn];
int sum[1 << (maxn - 2)], ans[1 << maxn];
void dfs(int x, int last)
{
if(belong[x]) return;
int sumk = 0, sumb = 0, deg = 1;
if(last == 0) deg--;
for (int i = fst[x]; i; i = arr[i].nxt)
{
int j = arr[i].tar;
if(j == last) continue;
dfs(j, x);
sumk += k[j], sumb += b[j], deg++;
sumk %= mod, sumb %= mod;
}
k[x] = 1ll * inv(((deg - sumk) % mod + mod) % mod);
b[x] = 1ll * (deg + sumb) % mod * inv(((deg - sumk) % mod + mod) % mod) % mod;
}
int Cnt[1 << maxn];
void init()
{
for (int i = 1; i < (1 << n); ++i)
{
Cnt[i] = __builtin_popcount(i);
memset(belong, 0, sizeof(belong));
memset(k, 0, sizeof(k));
memset(b, 0, sizeof(b));
for (int j = 1; j <= n; ++j) if(i & (1 << j - 1)) belong[j] = true;
dfs(x, 0);
sum[i] = b[x] * (Cnt[i] & 1 ? 1 : -1);
sum[i] = (sum[i] % mod + mod) % mod;
}
for (int j = 1; j <= n; ++j)
{
for (int i = 0; i < (1 << n); ++i)
{
if(!(i & (1 << j - 1)))
sum[i + (1 << j - 1)] += sum[i], sum[i + (1 << j - 1)] %= mod;
}
}
}
int main()
{
cin >> n >> q >> x;
for (int i = 1; i < n; ++i)
{
int x, y;
cin >> x >> y;
adds(x, y);
adds(y, x);
}
init();
while(q--)
{
int m, now = 0;
cin >> m;
for (int i = 1; i <= m; ++i)
{
int x;
cin >> x;
now |= 1 << x - 1;
}
printf("%d\n", sum[now]);
}
}