「PKUWC2018」隨機遊走 題解

Saltyfish6發表於2024-03-10

前言

好題,這個待定係數真的好用!

思路

說實話,看見這道題的第一反應真不是什麼 \(\texttt{min-max}\) 容斥,而是狀壓。

但是既然都放在連結裡面了,那我們就勉為其難的向這方面思考一下吧。

首先直接透過這個容斥轉化題意,也就是說,我們只需要求得,對於任意一個點的集合 \(S\),使得從 \(x\) 出發,想要第一次到達 \(S\) 中的任意一點,所需要的期望次數。(簡單來說就是透過該容斥將題意轉化為 \(E(\min(S))\) 。)

具體來說,從 \(0\)\(2^n\) 列舉所有子集 \(S\),對於每個子集,考慮樹形 \(\text{DP}\)

定義 \(dp_i\) 表示從 \(i\) 開始,到達 \(S\) 中的任意一點所需要的最少期望次數。

如果 \(i\in S\),顯然 \(dp_i=0\)。否則,顯然有 \(dp_i=\dfrac{dp_{fa_i}+\sum_{j\in son_i}dp_j}{deg_i} + 1\)。(其中,\(deg_i\) 表示 \(i\) 的度數)

但是顯然,這樣的轉移是有後效性的。

一種比較暴力的想法是,把所有的式子列出來,然後高斯消元,複雜度為 \(n^3\times 2^n\),而且是在模意義下進行,實現代價巨大,不具有可行性。

我們考慮一個新式的 \(\text{trick}\): 待定係數法。

我們發現,這個式子是齊次的(且是一次的),而且必然可以透過高斯消元得到一個 \(dp_i\) 無後效性的表示。故我們考慮待定係數,假設 \(dp_i=k_i\times dp_{fa_i}+b_i\)

我們可以透過這個假設,將轉移式子中的 \(dp_j\) 換掉,然後進行一波代數變形:

\[\begin{aligned} dp_i&=\dfrac{dp_{fa_i}+\left(\sum_{j\in son_i}k_j\right)\times dp_i+\sum b_j}{deg_i}+1 \\ (deg_i-\sum k_j)\times dp_i&=dp_{fa_i}+\sum b_j+deg_i\\ dp_i&=\dfrac{1}{(deg_i-\sum k_j)}\times dp_{fa_i}+\dfrac{\sum b_j+deg_i}{(deg_i-\sum k_j)} \end{aligned} \]

於是,我們有:\(k_i=\dfrac{1}{(deg_i-\sum k_j)},b_i=\dfrac{\sum b_j+deg_i}{(deg_i-\sum k_j)}\)

可以發現,\(k_i,b_i\) 的求法不具有後效性,且如果我們從 \(x\) 直接向下遞迴,則 \(dp_x=b_x\),因為沒有父親。

求出每個 \(S\) 之後,我們顯然不能每輸入一個就容斥一遍,也不能直接列舉子集,所以高維字首和,根據容斥,\(S\) 大小為奇數時就是加,否則減。

程式碼

#include <bits/stdc++.h> 
using namespace std;
#define maxn 20
#define mod 998244353
int ksm(int x, int y)
{
	int sum = 1;
	while(y)
	{
		if(y & 1) sum = 1ll * sum * x % mod;
		y >>= 1, x = 1ll * x * x % mod;
	}	
	return sum;
}
int inv(int x)
{
	return ksm(x, mod - 2);
}
int n, q, x;
int fst[maxn], cnt;
struct node
{
	int tar, nxt;
}arr[maxn << 1];
void adds(int x, int y)
{
	arr[++cnt].tar = y, arr[cnt].nxt = fst[x], fst[x] = cnt;
}
bool belong[maxn];
int k[maxn], b[maxn];
int sum[1 << (maxn - 2)], ans[1 << maxn];
void dfs(int x, int last)
{
	if(belong[x]) return;
	int sumk = 0, sumb = 0, deg = 1;
	if(last == 0) deg--;
	for (int i = fst[x]; i; i = arr[i].nxt)
	{
		int j = arr[i].tar;
		if(j == last) continue;
		dfs(j, x);
		sumk += k[j], sumb += b[j], deg++;
		sumk %= mod, sumb %= mod;
	}
	k[x] = 1ll * inv(((deg - sumk) % mod + mod) % mod);
	b[x] = 1ll * (deg + sumb) % mod * inv(((deg - sumk) % mod + mod) % mod) % mod;
}
int Cnt[1 << maxn];
void init()
{
	for (int i = 1; i < (1 << n); ++i)
	{
		Cnt[i] = __builtin_popcount(i);
		memset(belong, 0, sizeof(belong));
		memset(k, 0, sizeof(k));
		memset(b, 0, sizeof(b));
		for (int j = 1; j <= n; ++j) if(i & (1 << j - 1)) belong[j] = true;
		dfs(x, 0);
		sum[i] = b[x] * (Cnt[i] & 1 ? 1 : -1);
		sum[i] = (sum[i] % mod + mod) % mod;
	}
	for (int j = 1; j <= n; ++j)
	{
		for (int i = 0; i < (1 << n); ++i)
		{
			if(!(i & (1 << j - 1)))
			sum[i + (1 << j - 1)] += sum[i], sum[i + (1 << j - 1)] %= mod;
		}
	}
}
int main()
{
	cin >> n >> q >> x;
	for (int i = 1; i < n; ++i)
	{
		int x, y;
		cin >> x >> y;
		adds(x, y);
		adds(y, x);
	}
	init();
	while(q--)
	{
		int m, now = 0;
		cin >> m;
		for (int i = 1; i <= m; ++i)
		{
			int x;
			cin >> x;
			now |= 1 << x - 1;
		}
		printf("%d\n", sum[now]);
	}
}

相關文章