CF1039D You Are Given a Tree (樹形 dp + 貪心 + 根號分治)

Fire_Raku發表於2024-07-04

CF1039D You Are Given a Tree

樹形 dp + 貪心 + 根號分治

題目是一個經典問題,可以用樹形 dp 和貪心解決。設 \(f_u\) 表示以 \(u\) 節點為端點能夠剩下的最長路徑。考慮從葉子節點往上合併貪心,那麼如果能夠合併出包含 \(u\) 節點的大於等於 \(k\) 的路徑,那麼就合併, \(f_u=0\)

否則 \(f_u=\max(f_v+1)\)

如果每個 \(k\in [1,n]\) 都跑一次,那麼複雜度是 \(O(n^2)\)。考慮最佳化。

我們發現,隨著 \(k\) 增大,答案是一定不增的,並且假如當前列舉到 \(k\),那麼答案最大為 \(\lfloor\frac{n}{k}\rfloor\)。這樣除法的形式可以想到根號分治。一般的,如果 \(k\ge \sqrt{n}\),那麼答案 \(\lfloor\frac{n}{k}\rfloor\le\sqrt{n}\),也就是後面的答案取值只有 \(\sqrt{n}\)

結合上答案的單調性,不難想到二分每個段的右端點,這樣複雜度為 \(O(n\sqrt{n}\log n)\)

而對於 \(k\le\sqrt{n}\) 的部分,複雜度為 \(O(n\sqrt{n})\)。我們可以找到一個更好的值使得這兩部分儘可能均勻。

設這個值為 \(B\),那麼前面的複雜度為 \(O(nB)\),後面的複雜度為 \(O(\frac{n^2}{B}\log n)\),兩部分相等,得 \(B=\sqrt{n\log n}\)

所以總複雜度為 \(O(n\sqrt{n\log n})\)。但是這題很卡常,所以需要一些卡常技巧:

  1. 鏈式前向星存邊
  2. 快讀快寫
  3. 一遍 dfs 後將樹拍成 dfs 序後倒著遍歷,這也是最關鍵的卡常的地方。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back

using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e5 + 10;
int n, ans, k, g, tot, cnt;
int dfn[N], fat[N], mx[N], dmx[N];
struct node {
	int to, nxt;
} e[N << 1];
int f[N], h[N];
void add(int u, int v) {
	e[++cnt].to = v;
	e[cnt].nxt = h[u];
	h[u] = cnt;
}
void dfs(int u, int fa) { 
	dfn[++tot] = u; 
	fat[u] = fa; 
	for (int i = h[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(v == fa) continue;
		dfs(v, u);
	}
}
void dfs2() {
	for(int i = tot; i >= 1; i--) { 
		int u = dfn[i];
		if(f[mx[u]] + f[dmx[u]] + 1 >= k) ans++, f[u] = 0;
		else f[u] = f[mx[u]] + 1;

		int fa = fat[u];
		if(f[mx[fa]] < f[u]) dmx[fa] = mx[fa], mx[fa] = u;
		else if(f[dmx[fa]] < f[u]) dmx[fa] = u; 
	}
}
bool check(int x) {
	for(int i = 1; i <= n; i++) f[i] = mx[i] = dmx[i] = 0;
	k = x, ans = 0;
	dfs2();
	if(ans == g) return 1;
	return 0;
}
int main() {
	scanf("%d", &n);
	for (int i = 1; i < n; i++) {
		int u, v;
		scanf("%d%d", &u, &v);
		add(u, v), add(v, u);
	}	

	dfs(1, 0);

	int n1 = n, m = 0;
	while(n1) n1 >>= 1, m++;
	m++;
	m = sqrt(1LL * n * m);

	for (int i = 1; i <= n; i++) {
		k = i;
		if(i <= m) {
			for (int j = 1; j <= n; j++) f[j] = mx[j] = dmx[j] = 0;
			dfs2();
			printf("%d\n", ans);
		} else {
			for (int j = 1; j <= n; j++) f[j] = mx[j] = dmx[j] = 0;
			dfs2(); g = ans;
			int l = i, r = n, mx = i;
			while(l <= r) {
				int mid = (l + r) >> 1;
				if(check(mid)) l = mid + 1, mx = mid;
				else r = mid - 1;
			}
			for (int j = i; j <= mx; j++) printf("%d\n", g);
			i = mx;
		}
		ans = 0;
	}

	return 0;
}

相關文章