挺明顯的換根 DP。。
0x01
先考慮一下起點為 \(d\) 的答案該怎麼算。
我們發現可以欽定 \(d\) 為樹根,設 \(dp_i\) 表示以 \(i\) 為根的子樹中,以 \(i\) 為起點最多可以獲得的代價。由於有了進入節點的次數限制,我們肯定不能直接加和。不難發現這個限制其實就和點的度數有關,於是我們直接選擇 \(i\) 的兒子中,\(dp\) 值最大的 \(k_i-1\) 個就行了,排序即可。
0x02
我們嘗試根據起點為 \(d\) 時的求法擴充到其餘節點。
這裡就是換根 DP 了。我們記 \(f_i\) 表示以 \(i\) 為起點,且 \(i\) 只會往父親節點移動的最大代價。考慮用 \(f_{fa}\) 去更新 \(f_i\)。如果我們只能向 \(fa\) 移動,那麼 \(fa\) 肯定已經走過了一次,又因為我們的 \(fa\) 也在往上走,所以初始時 \(fa\) 走了兩次。也就是我們選擇 \(i\) 的兄弟節點中 \(dp\) 值最大的 \(k_{fa}-2\) 個節點,將它們的 \(dp\) 值加起來,最後再加上 \(f_{fa}\),就可以得到 \(dp_i\)。
可能有人要問了——為什麼 \(fa\) 非要往上走呢?為什麼不可以用 \(i\) 的兄弟節點中 \(dp\) 值最大的 \(k_{fa}-1\) 個呢??
因為我們必須要走到 \(d\) 節點至少一次,並且我們是以 \(d\) 為根進行 DP 的,那麼 \(fa\) 肯定就要往上走一次。也正是如此,我們在統計 \(d\) 的兒子的時候,才應該像上面這樣轉移,因為 \(fa\) 就是 \(d\),沒有必要往其它地方走了。
0x03
求出了 \(f_i\) 之後,我們就可以統計答案了。
我們列舉起點 \(s\)。首先要到達 \(d\) 至少一次,我們就加上一個 \(f_s\),然後再考慮將 \(k_s\) 給跑滿,所以我們還要加上 \(s\) 的兒子中 \(dp\) 值最大的 \(k_s-2\) 個的和。
比較最大值即可。
#include<bits/stdc++.h>
using namespace std;
const int MAXN=1e5+5;
int head[MAXN],nxt[MAXN<<1],to[MAXN<<1],val[MAXN<<1],tot;
void add(int x,int y,int z)
{
to[++tot]=y;
val[tot]=z;
nxt[tot]=head[x];
head[x]=tot;
}
int n,d;
int a[MAXN],in[MAXN];
int dp[MAXN],f[MAXN];
int stk[MAXN],cnt;
bool cmp(int x,int y){ return dp[x]<dp[y]; }
void dfs_first(int x,int fa)
{
for(int i=head[x];i;i=nxt[i])
{
if(to[i]==fa) continue;
dfs_first(to[i],x);
}
int tot=0;
for(int i=head[x];i;i=nxt[i])
{
if(to[i]==fa) continue;
dp[to[i]]+=val[i],stk[++tot]=to[i];
}
sort(stk+1,stk+tot+1,cmp);
for(int i=max(1,tot-a[x]+2);i<=tot;i++) dp[x]+=dp[stk[i]];
}
int maxx;
void dfs(int x,int fa)
{
if(a[x]==1) return;
cnt=0;
for(int i=head[x];i;i=nxt[i])
{
if(to[i]==fa) continue;
stk[++cnt]=to[i];
}
sort(stk+1,stk+cnt+1,cmp);
int now=0;
for(int i=cnt-a[x]+3;i<=cnt;i++) now+=dp[stk[i]];
maxx=max(maxx,f[x]+now);
for(int i=cnt-a[x]+3;i<=cnt;i++) f[stk[i]]=f[x]+now-dp[stk[i]]+dp[stk[max(1,cnt-a[x]+3)-1]];
for(int i=1;i<=cnt-a[x]+2;i++) f[stk[i]]=f[x]+now;
for(int i=head[x];i;i=nxt[i])
{
if(to[i]==fa) continue;
f[to[i]]+=val[i];
dfs(to[i],x);
}
}
int main()
{
// freopen("data.in","r",stdin);
// freopen("data.out","w",stdout);
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n>>d;
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
add(x,y,z),add(y,x,z);
in[x]++,in[y]++;
}
for(int i=1;i<=n;i++) cin>>a[i],a[i]=min(a[i],in[i]+1);
for(int i=head[d];i;i=nxt[i]) dfs_first(to[i],d);
cnt=0;
for(int i=head[d];i;i=nxt[i]) stk[++cnt]=to[i],dp[to[i]]+=val[i];
sort(stk+1,stk+cnt+1,cmp);
for(int i=max(1,cnt-a[d]+2);i<=cnt;i++) maxx+=dp[stk[i]];
for(int i=max(1,cnt-a[d]+2);i<=cnt;i++) f[stk[i]]=maxx-dp[stk[i]]+dp[stk[max(1,cnt-a[d]+2)-1]];
for(int i=1;i<=cnt-a[d]+1;i++) f[stk[i]]=maxx;
for(int i=head[d];i;i=nxt[i]) f[to[i]]+=val[i],dfs(to[i],d);
cout<<maxx;
return 0;
}