P2495 [SDOI2011] 消耗戰

Fire_Raku發表於2024-04-06

P2495 [SDOI2011] 消耗戰

虛樹最佳化 dp 模板題

考慮 \(m=1\)。只需要簡單的樹形 dp,設 \(f_i\) 表示 \(i\) 子樹中的關鍵點都到不了 \(i\) 點的最小代價。轉移列舉子節點 \(v\),有:

\(v\) 點為關鍵點,\(f_u=f_u+w(u,v)\)

否則,\(f_u=f_u+\min(f_v,w(u,v))\)

如果每次詢問都跑一遍,複雜度 \(O(nm)\)。考慮最佳化。

我們發現這題最關鍵的一點是,我們轉移時訪問的點很多都是無用的。事實上,我們只需要儲存關鍵點以及關鍵點的 \(lca\) 即可轉移。所以我們需要建出一棵新樹滿足這樣的要求。

虛樹,在原樹中保留關鍵點以及兩兩的公共祖先和樹根所構成的樹。如何建出虛樹?先將關鍵點按 \(dfs\) 序從小到大排序。我們考慮不斷用棧維護一條最右鏈,當列舉一個關鍵點 \(v\) 時,求出 \(rt=lca(s[top],v)\),有以下情況:

\(rt=s[top]\),說明 \(v\) 在當前最右鏈上,直接將 \(v\) 插入棧即可。

否則,考慮一直彈棧直到沒有點的深度大於 \(rt\),彈棧的同時連邊,最後再插入 \(v\)。這部分細節多,這裡只是概述簡要思想,具體要用圖才能說清楚。下圖為一般情況。

到現在,建立虛樹的程式碼呼之欲出。

void build() {
	std::sort(a + 1, a + k + 1, cmp); //按 dfs 序排序
	st[++top] = 1;
	for(int i = 1; i <= k; i++) {
		int rt = lca(a[i], st[top]);
		while(top && dep[st[top - 1]] >= dep[rt]) {
			add(st[top - 1], st[top]), top--;
		} //彈棧時連邊
		if(st[top] != rt) {
			add(rt, st[top]), st[top] = rt;
		} //特殊情況,rt 為新點,連邊後覆蓋 st[top]
		st[++top] = a[i]; //最後插入 v
	}
	while(top > 1) {
		add(st[top - 1], st[top]);
		top--;
	} //最後連上最右鏈
	top = 0;
}

在這題裡,虛樹的邊顯然是路徑上的最小值,倍增預處理即可。

建好虛樹後在虛樹上跑樹形 dp 即可。總複雜度是 \(O(\sum k\log \sum k)=O(n\log n)\)

#include <bits/stdc++.h>
#define pii std::pair<int, i64>
#define fi first
#define se second
#define pb push_back

typedef long long i64;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 250010;
int n, m, k, tot;
int a[N];
int anc[N][20], dfn[N], dep[N];
i64 mn[N][20];
std::vector<pii> V[N];
void dfs(int u, int fa) {
	anc[u][0] = fa;
	dfn[u] = ++tot;
	dep[u] = dep[fa] + 1;
	for(int j = 1; j <= 19; j++) {
		anc[u][j] = anc[anc[u][j - 1]][j - 1];
		mn[u][j] = std::min(mn[u][j - 1], mn[anc[u][j - 1]][j - 1]);
	}
	for(auto v : V[u]) {
		if(v.fi == fa) continue;
		mn[v.fi][0] = v.se;
		dfs(v.fi, u);
	}
}
int lca(int u, int v) {
	if(dep[u] < dep[v]) std::swap(u, v);
	for(int i = 19; i >= 0; i--) if(dep[anc[u][i]] >= dep[v]) u = anc[u][i];
	if(u == v) return u;
	for(int i = 19; i >= 0; i--) if(anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
	return anc[u][0];
}
int cnt;
int h[N];
struct node {
	int to, nxt;
	i64 w;
} e[N << 1];
void add(int u, int v, i64 w) {
	e[++cnt].to = v, e[cnt].nxt = h[u], e[cnt].w = w;
	h[u] = cnt;
}
bool cmp(int a, int b) {
	return dfn[a] < dfn[b];
}
i64 calc(int u, int v) {
	i64 ret = linf;
	for(int i = 19; i >= 0; i--) {
		if(dep[anc[u][i]] > dep[v]) ret = std::min(ret, mn[u][i]), u = anc[u][i]; 
	}
	return std::min(ret, mn[u][0]);
}
int st[N], top;
void build() {
	std::sort(a + 1, a + k + 1, cmp);
	st[++top] = 1;
	for(int i = 1; i <= k; i++) {
		int rt = lca(a[i], st[top]);
		while(top && dep[st[top - 1]] >= dep[rt]) {
			i64 dis = calc(st[top], st[top - 1]);
			add(st[top - 1], st[top], dis), top--;
		}
		if(st[top] != rt) {
			i64 dis = calc(st[top], rt);
			add(rt, st[top], dis), st[top] = rt;
		}
		st[++top] = a[i];
	}
	while(top > 1) {
		i64 dis = calc(st[top], st[top - 1]);
		add(st[top - 1], st[top], dis);
		top--;
	}
	top = 0;
}
bool vis[N];
i64 f[N];
void dp(int u, int fa) {
	for(int i = h[u]; i; i = e[i].nxt) {
		int v = e[i].to; i64 w = e[i].w;
		if(v == fa) continue;
		dp(v, u);
		if(vis[v]) f[u] += w;
		else f[u] += std::min(f[v], w);
		f[v] = 0, vis[v] = 0;
	}
	h[u] = 0;
}
void Solve() {
	std::cin >> n;
	for(int i = 1; i < n; i++) {
		int u, v, w;
		std::cin >> u >> v >> w;
		V[u].pb({v, w}), V[v].pb({u, w});
	}
	dfs(1, 0);
	std::cin >> m;
	while(m--) {
		std::cin >> k;
		for(int i = 1; i <= k; i++) {
			std::cin >> a[i];
			vis[a[i]] = 1;
		}
		build();
		dp(1, 0);
		std::cout << f[1] << "\n"; f[1] = 0;
		cnt = 0;
	}
}
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
	Solve();

	return 0;
}

相關文章