[Codeforces 1111E] Tree(虛樹+二項式反演)

WAautomaton發表於2019-02-20

題目連結

題目大意

給定一棵樹,有一些詢問。每次詢問給出kk個點和兩個數m,rm,r,表示讓原樹以rr為根,把這kk個點分成至多mm組,每組內不存在一個點是另一個點的祖先。求方案數膜1000000007.
n,Q105,k105,mmin(k,300)n,Q\le 10^5,\sum k\le 10^5,m\le min(k,300)

題解

顯然先建虛樹,並且按照給定根重新遍歷虛樹。剛開始SB的我想了好久怎麼重新確定虛樹中誰是誰的祖先……後來才發現直接把rr加進去一起建虛樹就行了qaq。
然後,看資料範圍似乎是個O(km)O(km)的做法?想了一會兒樹形dp,感覺不太可行。那就估計是組合數學了。
先不考慮組與組之間無區別的問題(即兩組分別為{1},{2}和{2},{1}實際上是相同的情況),我們給每個組設定一個編號。遍歷虛樹,如果某個點向上有xx個祖先,那麼它可以選的編號有mxm-x種,乘起來即可。
顯然這樣會重複,我們考慮去重。不妨令f(m)f(m)表示剛剛算出的答案,g(m)g(m)表示恰好分成mm非空無區別組的方案數。那麼:
f(m)=i=1m(mi)g(i)i!f(m)=\sum_{i=1}^m\binom mi g(i)\cdot i!
二項式反演即可得到:
g(m)=1i!i=1m(1)mi(mi)f(i)g(m)=\frac{1}{i!}\sum_{i=1}^m(-1)^{m-i}\binom mi f(i)
於是我們可以在O(km)O(km)的時間內算出所有的ff,利用ffO(m2)O(km)O(m^2)\le O(km)的時間內算出所有的gg,直接求和就是答案。

#include <bits/stdc++.h>
namespace IOStream {
	const int MAXR = 1 << 23;
	char _READ_[MAXR], _PRINT_[MAXR];
	int _READ_POS_, _PRINT_POS_, _READ_LEN_;
	inline char readc() {
	#ifndef ONLINE_JUDGE
		return getchar();
	#endif
		if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
		char c = _READ_[_READ_POS_++];
		if (_READ_POS_ == MAXR) _READ_POS_ = 0;
		if (_READ_POS_ > _READ_LEN_) return 0;
		return c;
	}
	template<typename T> inline void read(T &x) {
		x = 0; register int flag = 1, c;
		while (((c = readc()) < '0' || c > '9') && c != '-');
		if (c == '-') flag = -1; else x = c - '0';
		while ((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
		x *= flag;
	}
	template<typename T1, typename ...T2> inline void read(T1 &a, T2 &...x) {
		read(a), read(x...);
	}
	inline int reads(char *s) {
		register int len = 0, c;
		while (isspace(c = readc()) || !c);
		s[len++] = c;
		while (!isspace(c = readc()) && c) s[len++] = c;
		s[len] = 0;
		return len;
	}
	inline void ioflush() {
		fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0;
		fflush(stdout);
	}
	inline void printc(char c) {
		_PRINT_[_PRINT_POS_++] = c;
		if (_PRINT_POS_ == MAXR) ioflush();
	}
	inline void prints(char *s) {
		for (int i = 0; s[i]; i++) printc(s[i]);
	}
	template<typename T> inline void print(T x, char c = '\n') {
		if (x < 0) printc('-'), x = -x;
		if (x) {
			static char sta[20];
			register int tp = 0;
			for (; x; x /= 10) sta[tp++] = x % 10 + '0';
			while (tp > 0) printc(sta[--tp]);
		} else printc('0');
		printc(c);
	}
	template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
		print(x, ' '), print(y...);
	}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
typedef pair<int, int> P;
#define cls(a) memset(a, 0, sizeof(a))

const int MAXN = 100005, MAXM = 200005, MOD = 1000000007;
struct Graph { int to, next; } gra[MAXM];
struct Edge { int to, val, next; } edge[MAXM];
int hd[MAXN], st[20][MAXM], beg[MAXN], dep[MAXN], sta[MAXN], ed[MAXN];
int lg[MAXM], head[MAXN], arr[MAXN], vis[MAXN], sz[MAXN], n, m, tot;
void addgra(int u, int v) {
    gra[++tot] = (Graph) { v, hd[u] };
    hd[u] = tot;
}
void addedge(int u, int v, int w) {
    edge[++tot] = (Edge) { v, w, head[u] };
    head[u] = tot;
    edge[++tot] = (Edge) { u, w, head[v] };
    head[v] = tot;
    //printf("%d %d %d\n", u, v, w);
}
void dfs1(int u, int fa) {
    dep[st[0][beg[u] = ++tot] = u] = dep[fa] + 1;
    sz[u] = 1;
    for (int i = hd[u]; i; i = gra[i].next) {
        int v = gra[i].to;
        if (v != fa) dfs1(v, st[0][++tot] = u), sz[u] += sz[v];
    }
    ed[u] = tot;
}
int get_min(int a, int b) { return dep[a] < dep[b] ? a : b; }
int get_lca(int a, int b) {
    a = beg[a], b = beg[b];
    if (a > b) swap(a, b);
    int l = lg[b - a + 1];
    return get_min(st[l][a], st[l][b - (1 << l) + 1]);
}
bool cmp(const int &a, const int &b) { return beg[a] < beg[b]; }
int q, r, mm;
ll C[305][305], f[305], fac[305], rev[305];
ll modpow(ll a, int b) {
	ll res = 1;
	for (; b; b >>= 1) {
		if (b & 1) res = res * a % MOD;
		a = a * a % MOD;
	}
	return res;
}
void dfs4(int u, int fa) {
	for (int &i = head[u]; i; i = edge[i].next) {
		int v = edge[i].to;
		if (v != fa) dfs4(v, u);
	}
}
void dfs3(int u, int fa, int d, ll &ff) {
	for (int i = head[u]; i; i = edge[i].next) {
		int v = edge[i].to;
		if (v == fa) continue;
		dfs3(v, u, d - vis[u], ff);
	}
	if (vis[u]) (ff *= d) %= MOD;
}
int main() {
	C[0][0] = 1;
	for (int i = fac[0] = 1; i <= 300; i++) {
		fac[i] = fac[i - 1] * i % MOD;
		C[i][0] = 1;
		for (int j = 1; j <= i; j++)
			C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % MOD;
	}
	rev[300] = modpow(fac[300], MOD - 2);
	for (int i = 300; i > 0; i--) rev[i - 1] = rev[i] * i % MOD;
    read(n, m);
    for (int i = 1; i < n; i++) {
        int u, v; read(u, v);
        addgra(u, v);
        addgra(v, u);
    }
    dfs1(1, tot = 0);
    for (int i = 2; i <= tot; i++) lg[i] = lg[i >> 1] + 1;
    for (int i = 1; i < 20; i++)
    for (int j = 1; j + (1 << i) - 1 <= tot; j++)
        st[i][j] = get_min(st[i - 1][j], st[i - 1][j + (1 << i >> 1)]);
    while (m--) {
        int top = tot = 0, flag = 0; read(q, mm, r);
        for (int i = 1; i <= q; i++) {
        	read(arr[i]), vis[hd[i] = arr[i]] = 1;
        	if (arr[i] == r) flag = 1;
        }
        if (!flag) arr[++q] = r;
        sort(arr + 1, arr + 1 + q, cmp);
        sta[++top] = 1;
        for (int i = arr[1] == 1 ? 2 : 1; i <= q; i++) {
            int l = get_lca(sta[top], arr[i]);
            for (; top > 1 && dep[sta[top - 1]] > dep[l]; top--)
                addedge(sta[top - 1], sta[top], dep[sta[top]] - dep[sta[top - 1]]);
            if (dep[sta[top]] > dep[l]) addedge(l, sta[top], dep[sta[top]] - dep[l]), --top;
            if (dep[sta[top]] < dep[l]) sta[++top] = l;
            sta[++top] = arr[i];
        }
        for (; top > 1; top--) addedge(sta[top - 1], sta[top], dep[sta[top]] - dep[sta[top - 1]]);
        ll res = 0;
        for (int i = 1; i <= mm; i++) {
        	f[i] = 1;
        	dfs3(r, 0, i, f[i]);
        	ll sum = 0;
        	for (int j = 1; j <= i; j++) {
        		if ((i - j) & 1) (sum -= C[i][j] * f[j]) %= MOD;
        		else (sum += C[i][j] * f[j]) %= MOD;
        	}
        	(res += sum * rev[i]) %= MOD;
        }
        for (int i = 1; i <= q; i++) vis[arr[i]] = 0;
        dfs4(r, 0);
        print((res + MOD) % MOD);
    }
    ioflush();
    return 0;
}

相關文章