CF1039D You Are Given a Tree
樹形 dp + 貪心 + 根號分治
題目是一個經典問題,可以用樹形 dp 和貪心解決。設 \(f_u\) 表示以 \(u\) 節點為端點能夠剩下的最長路徑。考慮從葉子節點往上合併貪心,那麼如果能夠合併出包含 \(u\) 節點的大於等於 \(k\) 的路徑,那麼就合併, \(f_u=0\);
否則 \(f_u=\max(f_v+1)\)。
如果每個 \(k\in [1,n]\) 都跑一次,那麼複雜度是 \(O(n^2)\)。考慮最佳化。
我們發現,隨著 \(k\) 增大,答案是一定不增的,並且假如當前列舉到 \(k\),那麼答案最大為 \(\lfloor\frac{n}{k}\rfloor\)。這樣除法的形式可以想到根號分治。一般的,如果 \(k\ge \sqrt{n}\),那麼答案 \(\lfloor\frac{n}{k}\rfloor\le\sqrt{n}\),也就是後面的答案取值只有 \(\sqrt{n}\) 種。
結合上答案的單調性,不難想到二分每個段的右端點,這樣複雜度為 \(O(n\sqrt{n}\log n)\)。
而對於 \(k\le\sqrt{n}\) 的部分,複雜度為 \(O(n\sqrt{n})\)。我們可以找到一個更好的值使得這兩部分儘可能均勻。
設這個值為 \(B\),那麼前面的複雜度為 \(O(nB)\),後面的複雜度為 \(O(\frac{n^2}{B}\log n)\),兩部分相等,得 \(B=\sqrt{n\log n}\)。
所以總複雜度為 \(O(n\sqrt{n\log n})\)。但是這題很卡常,所以需要一些卡常技巧:
- 鏈式前向星存邊
- 快讀快寫
- 一遍 dfs 後將樹拍成 dfs 序後倒著遍歷,這也是最關鍵的卡常的地方。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define mk std::make_pair
#define fi first
#define se second
#define pb push_back
using i64 = long long;
using ull = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e5 + 10;
int n, ans, k, g, tot, cnt;
int dfn[N], fat[N], mx[N], dmx[N];
struct node {
int to, nxt;
} e[N << 1];
int f[N], h[N];
void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = h[u];
h[u] = cnt;
}
void dfs(int u, int fa) {
dfn[++tot] = u;
fat[u] = fa;
for (int i = h[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fa) continue;
dfs(v, u);
}
}
void dfs2() {
for(int i = tot; i >= 1; i--) {
int u = dfn[i];
if(f[mx[u]] + f[dmx[u]] + 1 >= k) ans++, f[u] = 0;
else f[u] = f[mx[u]] + 1;
int fa = fat[u];
if(f[mx[fa]] < f[u]) dmx[fa] = mx[fa], mx[fa] = u;
else if(f[dmx[fa]] < f[u]) dmx[fa] = u;
}
}
bool check(int x) {
for(int i = 1; i <= n; i++) f[i] = mx[i] = dmx[i] = 0;
k = x, ans = 0;
dfs2();
if(ans == g) return 1;
return 0;
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
dfs(1, 0);
int n1 = n, m = 0;
while(n1) n1 >>= 1, m++;
m++;
m = sqrt(1LL * n * m);
for (int i = 1; i <= n; i++) {
k = i;
if(i <= m) {
for (int j = 1; j <= n; j++) f[j] = mx[j] = dmx[j] = 0;
dfs2();
printf("%d\n", ans);
} else {
for (int j = 1; j <= n; j++) f[j] = mx[j] = dmx[j] = 0;
dfs2(); g = ans;
int l = i, r = n, mx = i;
while(l <= r) {
int mid = (l + r) >> 1;
if(check(mid)) l = mid + 1, mx = mid;
else r = mid - 1;
}
for (int j = i; j <= mx; j++) printf("%d\n", g);
i = mx;
}
ans = 0;
}
return 0;
}