「SWTR-4」Collecting Coins 題解

Supor__Shoop發表於2024-09-27

挺明顯的換根 DP。。

0x01

先考慮一下起點為 \(d\) 的答案該怎麼算。

我們發現可以欽定 \(d\) 為樹根,設 \(dp_i\) 表示以 \(i\) 為根的子樹中,以 \(i\) 為起點最多可以獲得的代價。由於有了進入節點的次數限制,我們肯定不能直接加和。不難發現這個限制其實就和點的度數有關,於是我們直接選擇 \(i\) 的兒子中,\(dp\) 值最大的 \(k_i-1\) 個就行了,排序即可。

0x02

我們嘗試根據起點為 \(d\) 時的求法擴充到其餘節點。

這裡就是換根 DP 了。我們記 \(f_i\) 表示以 \(i\) 為起點,且 \(i\) 只會往父親節點移動的最大代價。考慮用 \(f_{fa}\) 去更新 \(f_i\)。如果我們只能向 \(fa\) 移動,那麼 \(fa\) 肯定已經走過了一次,又因為我們的 \(fa\) 也在往上走,所以初始時 \(fa\) 走了兩次。也就是我們選擇 \(i\) 的兄弟節點中 \(dp\) 值最大的 \(k_{fa}-2\) 個節點,將它們的 \(dp\) 值加起來,最後再加上 \(f_{fa}\),就可以得到 \(dp_i\)

可能有人要問了——為什麼 \(fa\) 非要往上走呢?為什麼不可以用 \(i\) 的兄弟節點中 \(dp\) 值最大的 \(k_{fa}-1\) 個呢??

因為我們必須要走到 \(d\) 節點至少一次,並且我們是以 \(d\) 為根進行 DP 的,那麼 \(fa\) 肯定就要往上走一次。也正是如此,我們在統計 \(d\) 的兒子的時候,才應該像上面這樣轉移,因為 \(fa\) 就是 \(d\),沒有必要往其它地方走了。

0x03

求出了 \(f_i\) 之後,我們就可以統計答案了。

我們列舉起點 \(s\)。首先要到達 \(d\) 至少一次,我們就加上一個 \(f_s\),然後再考慮將 \(k_s\) 給跑滿,所以我們還要加上 \(s\) 的兒子中 \(dp\) 值最大的 \(k_s-2\) 個的和。

比較最大值即可。

#include<bits/stdc++.h>
using namespace std;
const int MAXN=1e5+5;
int head[MAXN],nxt[MAXN<<1],to[MAXN<<1],val[MAXN<<1],tot;
void add(int x,int y,int z)
{
	to[++tot]=y;
	val[tot]=z;
	nxt[tot]=head[x];
	head[x]=tot;
}
int n,d;
int a[MAXN],in[MAXN];
int dp[MAXN],f[MAXN];
int stk[MAXN],cnt;
bool cmp(int x,int y){ return dp[x]<dp[y]; }
void dfs_first(int x,int fa)
{
	for(int i=head[x];i;i=nxt[i])
	{
		if(to[i]==fa)	continue;
		dfs_first(to[i],x);
	}
	int tot=0;
	for(int i=head[x];i;i=nxt[i])
	{
		if(to[i]==fa)	continue;
		dp[to[i]]+=val[i],stk[++tot]=to[i];
	}
	sort(stk+1,stk+tot+1,cmp);
	for(int i=max(1,tot-a[x]+2);i<=tot;i++)	dp[x]+=dp[stk[i]];
}
int maxx;
void dfs(int x,int fa)
{
	if(a[x]==1)	return;
	cnt=0;
	for(int i=head[x];i;i=nxt[i])
	{
		if(to[i]==fa)	continue;
		stk[++cnt]=to[i];
	}
	sort(stk+1,stk+cnt+1,cmp);
	int now=0;
	for(int i=cnt-a[x]+3;i<=cnt;i++)	now+=dp[stk[i]];
	maxx=max(maxx,f[x]+now);
	for(int i=cnt-a[x]+3;i<=cnt;i++)	f[stk[i]]=f[x]+now-dp[stk[i]]+dp[stk[max(1,cnt-a[x]+3)-1]];
	for(int i=1;i<=cnt-a[x]+2;i++)	f[stk[i]]=f[x]+now;
	for(int i=head[x];i;i=nxt[i])
	{
		if(to[i]==fa)	continue;
		f[to[i]]+=val[i];
		dfs(to[i],x);
	}
}
int main()
{
//	freopen("data.in","r",stdin);
//	freopen("data.out","w",stdout);
	ios::sync_with_stdio(false);
	cin.tie(0),cout.tie(0);
	cin>>n>>d;
	for(int i=1;i<n;i++)
	{
		int x,y,z;
		cin>>x>>y>>z;
		add(x,y,z),add(y,x,z);
		in[x]++,in[y]++;
	}
	for(int i=1;i<=n;i++)	cin>>a[i],a[i]=min(a[i],in[i]+1);
	for(int i=head[d];i;i=nxt[i])	dfs_first(to[i],d);
	cnt=0;
	for(int i=head[d];i;i=nxt[i])	stk[++cnt]=to[i],dp[to[i]]+=val[i];
	sort(stk+1,stk+cnt+1,cmp);
	for(int i=max(1,cnt-a[d]+2);i<=cnt;i++)	maxx+=dp[stk[i]];
	for(int i=max(1,cnt-a[d]+2);i<=cnt;i++)	f[stk[i]]=maxx-dp[stk[i]]+dp[stk[max(1,cnt-a[d]+2)-1]];
	for(int i=1;i<=cnt-a[d]+1;i++)	f[stk[i]]=maxx;
	for(int i=head[d];i;i=nxt[i])	f[to[i]]+=val[i],dfs(to[i],d);
	cout<<maxx;
	return 0;
}