題目
傳送門
\(d_i\) 表示初始使用的能量,如果你處理了這個點,就可以把相鄰的點所需的能量減去 \(c_i\)
思路
前 \(50\%\) 的資料為 \(\max\{c_i\}=1,n\le10^5\)
因為最大的 \(c_i\) 只有 \(1\),我們可以考慮貪心
首先,我們肯定先處理 \(c_i=1\) 的點,我們要考慮處理的順序,不管順序怎麼變,遲早會把該節省的節省掉
int sum = 0;
for (int i = 1; i <= n; ++i)
{
if (c[i] == 0)
continue;
vis[i] = 1;
for (int v : e[i])
if (!vis[v]) --d[v];
}
for (int i = 1; i <= n; ++i)
sum += max(d[i], 0);
後 \(50\%\) 的資料,我們就要考慮使用樹形dp (就是個揹包)
定義
\(f[u][0/1]\) 表示當前節點是否在父節點之前處理完
\(g[i][0/1]\) 表示節省 \(i\) 元的最小花費
對於 \(g\) 的詳細操作
int sz = c[fa];
for (int i = 0; i <= sz; ++i) g[i][0] = g[i][1] = inf;
g[0][1] = g[c[fa]][0] = 0;
for (int v : e[u])
{
if (v == fa) continue;
for (int i = sz + 1; i <= sz + c[v]; ++i) g[i][0] = g[i][1] = inf;
sz += c[v];
for (int i = sz; i >= 0; --i)
{
if (i < c[v])
{
g[i][0] += f[v][0];
g[i][1] += f[v][0];
}else
{
g[i][0] = min(f[v][0] + g[i][0], f[v][1] + g[i - c[v]][0]);
g[i][1] = min(f[v][0] + g[i][1], f[v][1] + g[i - c[v]][1]);
}
}
}
得到\(g\)之後就好處理了,我們可寫出如下轉移
\(\displaystyle f[u][0] = \min_{i=1}^{sz}\{g[i][0]+\max\{0,d[u]-i\}\}\)
\(\displaystyle f[u][1] = \min_{i=1}^{sz}\{g[i][1]+\max\{0,d[u]-i\}\}\)
程式碼
#include <bits/stdc++.h>
#define ll long long
#define PII pair<int, int>
using namespace std;
const int inf = 0x3f3f3f3f;
const int MOD = 1e9 + 7, N = 1e5 + 5;
int n, d[N], c[N];
vector<int> e[N];
namespace pts1
{
int f[N][2], g[N][2];
void dfs(int u, int fa)
{
for (int v: e[u])
{
if (v == fa) continue;
dfs(v, u);
}
int sz = c[fa];
for (int i = 0; i <= sz; ++i) g[i][0] = g[i][1] = inf;
g[0][1] = g[c[fa]][0] = 0;
for (int v: e[u])
{
if (v == fa) continue;
for (int i = sz + 1; i <= sz + c[v]; ++i) g[i][0] = g[i][1] = inf;
sz += c[v];
for (int i = sz; i >= 0; --i)
{
if (i < c[v])
{
g[i][0] += f[v][0];
g[i][1] += f[v][0];
} else
{
g[i][0] = min(f[v][0] + g[i][0], f[v][1] + g[i - c[v]][0]);
g[i][1] = min(f[v][0] + g[i][1], f[v][1] + g[i - c[v]][1]);
}
}
}
f[u][0] = g[0][0] + d[u], f[u][1] = g[0][1] + d[u];
for (int i = 1; i <= sz; ++i)
{
f[u][0] = min(f[u][0], g[i][0] + max(0, d[u] - i));
f[u][1] = min(f[u][1], g[i][1] + max(0, d[u] - i));
}
}
void solve()
{
dfs(1, 0);
printf("%d", f[1][1]);
}
};
namespace pts2
{
int vis[N];
void solve()
{
int sum = 0;
for (int i = 1; i <= n; ++i)
{
if (c[i] == 0)
continue;
vis[i] = 1;
for (int v: e[i])
if (!vis[v]) --d[v];
}
for (int i = 1; i <= n; ++i)
sum += max(d[i], 0);
printf("%d", sum);
}
};
signed main()
{
int maxx = 0;
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", &d[i]);
for (int i = 1; i <= n; ++i)
scanf("%d", &c[i]), maxx = max(maxx, c[i]);
for (int i = 1, u, v; i < n; ++i)
{
scanf("%d%d", &u, &v);
e[u].push_back(v);
e[v].push_back(u);
}
if (maxx <= 1) pts2::solve();
else pts1::solve();
return 0;
}