P11324 【MX-S7-T2】「SMOI-R2」Speaker - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)
思路：本題主要是較繁瑣的分類討論，最核心的部分是用樹上倍增求樹鏈（路徑）上的最大值。細節不再贅述。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long LL;
// Capacities: N nodes; M = 2N directed edges because every tree edge is stored in both directions.
const int N = 200010, M = N * 2;
// Number of binary-lifting levels (fa[][j] / maxv[][j] for 0 <= j < K).
// This used to be `log2(N) + 2`, which is not a constant expression in standard C++
// (only GCC folds it), so it could not portably size the arrays below.
// floor(log2(200010)) + 2 == 19, so the value is unchanged.
constexpr int K = 19;
static_assert((1 << (K - 1)) >= N, "K too small: the largest lifting jump must cover any depth difference");
int n, m;
// Adjacency list ("chain forward star"): h[u] = first edge out of u (-1 = none, set in main),
// e[i] = edge target, w[i] = edge weight, ne[i] = next edge out of the same node,
// idx = number of edge slots used so far.
int h[N], e[M], ne[M], w[M], idx;
LL v[N];  // node values, read in main
// depth[u]: 1 for the root, 0 for the sentinel node 0.
LL depth[N];
// fa[u][j]: 2^j-th ancestor of u (0 above the root). Node ids fit in int, which halves
// this table compared with long long.
int fa[N][K];
// maxv[u][j]: maximum folded over the 2^j-step climb starting at u (filled by dfs2).
LL maxv[N][K];
// dist[u]: weighted distance from the root (node 1).
// max1/max2/max3[u]: the three best values of v[z] - 2 * dist(u, z) in u's subtree, taken
// through distinct children; mak1/mak2/mak3[u] is that child (0 = u itself / unset).
LL dist[N], max1[N], max2[N], max3[N];
LL mak1[N], mak2[N], mak3[N];
// sum[u]: best v[z] - 2 * dist(u, z) over z outside u's subtree (stays 0 at the root).
// res, kk: per-query outputs of lca() (running maximum, and a path node below the LCA).
LL sum[N], res, kk;
// Append the directed edge a -> b with weight c to the front of a's adjacency list.
void add(int a, int b, int c)
{
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];  // old head becomes the successor
    h[a] = idx;
    ++idx;
}
void dfs1(int u, int fas)
{
max1[u] = v[u];
for (int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if (fas == j) continue;
dist[j] = dist[u] + w[i];
dfs1(j, u);
// cout << u << ' ' << max1[u] << endl;
// cout << j << ' ' << max1[j] << endl;
if (max1[j] - 2 * w[i] >= max1[u])
{
max3[u] = max2[u], mak3[u] = mak2[u];
max2[u] = max1[u]; mak2[u] = mak1[u];
max1[u] = max1[j] - 2 * w[i], mak1[u] = j;
// cout << max1[u] << endl;
}
else if (max1[j] - 2 * w[i] >= max2[u])
{
max3[u] = max2[u], mak3[u] = mak2[u];
max2[u] = max1[j] - 2 * w[i], mak2[u] = j;
}
else if (max1[j] - 2 * w[i] >= max3[u])
{
max3[u] = max1[j] - 2 * w[i];
mak3[u] = j;
}
// cout << 'a';
}
}
void dfs2(int u, int fas)
{
depth[u] = depth[fas] + 1;
fa[u][0] = fas;
if (mak1[fas] != u) maxv[u][0] = max(v[u], max1[fas]);
else maxv[u][0] = max(v[u], max2[fas]);
for (int j = 1; j < K; j ++ )
{
fa[u][j] = fa[fa[u][j - 1]][j - 1];
maxv[u][j] = max(maxv[u][j - 1], maxv[fa[u][j - 1]][j - 1]);
}
for (int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if (j == fas) continue;
if (mak1[u] != j) sum[j] = max(sum[u], max1[u]) - 2 * w[i];
else sum[j] = max(sum[u], max2[u]) - 2 * w[i];
dfs2(j, u);
}
}
// Binary-lifting LCA of a and b. Besides returning the LCA it updates two globals:
//   res - raised to the max of every maxv[][] segment climbed on either side, plus the
//         best side value at the LCA itself (see the tail). It is only ever raised, so
//         the caller resets it first (main() sets res = 0 before each call).
//   kk  - a node on the path below the LCA: in the general case the LCA's child on the
//         a-side (a = the deeper endpoint after the swap); when one endpoint is an
//         ancestor of the other, the node where the last depth-equalising jump started.
// NOTE(review): in the ancestor case kk is the LCA's direct child only if that last jump
// had length 1, and kk keeps its previous value when a == b on entry. main() uses kk only
// to choose between mak1/max1 and max2; keeping max1 just adds a value that appears to
// under-estimate a candidate already counted elsewhere, so this looks harmless — confirm
// before relying on kk for anything else.
int lca(int a, int b)
{
// Make a the deeper endpoint.
if (depth[a] < depth[b]) swap(a, b);
// Lift a to b's depth, folding in the maximum of each 2^j-step segment climbed.
for (int j = K - 1; j >= 0; j -- )
if (depth[fa[a][j]] >= depth[b])
{
if (depth[fa[a][j]] == depth[b]) kk = a;  // this jump lands exactly on b's depth
res = max(res, maxv[a][j]);
a = fa[a][j];
}
// b was an ancestor of a: the LCA is b; also count its own value.
if (a == b)
{
res = max(res, v[a]);
return a;
}
// Lift both until they are distinct children of the LCA.
for (int j = K - 1; j >= 0; j -- )
if (fa[a][j] != fa[b][j])
{
res = max({res, maxv[a][j], maxv[b][j]});
a = fa[a][j];
b = fa[b][j];
}
kk = a;
int u = fa[a][0];  // the LCA
// At the LCA take the first of max1/max2/max3 whose child tag is neither a nor b,
// i.e. a side branch (or the LCA itself, tag 0) that is not on the query path.
if (a != mak1[u] && b != mak1[u]) res = max(res, max1[u]);
else if (a != mak2[u] && b != mak2[u]) res = max(res, max2[u]);
else res = max(res, max3[u]);
return u;
}
int main()
{
// freopen("speaker4.in", "r", stdin);
// freopen("speaker4.out", "w", stdout);
cin >> n >> m;
for (int i = 1; i <= n; i ++ ) scanf("%lld", &v[i]);
memset(h, -1, sizeof h);
for (int i = 1; i < n; i ++ )
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c);
add(b, a, c);
// cout << a << ' ' << b << ' ' << c << endl;
}
dfs1(1, 0);
dfs2(1, 0);
while (m -- )
{
int a, b;
res = 0;
scanf("%d%d", &a, &b);
int p = lca(a, b);
// printf("%d %d %d\n", a, b, p);
LL ans = max({sum[p], res});
if (p == a)
{
if (mak1[a] != kk) ans = max({ans, max1[b], max1[a]});
else ans = max({ans, max1[b], max2[a]});
}
else if (p == b)
{
if (mak1[b] != kk) ans = max({ans, max1[a], max1[b]});
else ans = max({ans, max1[a], max2[b]});
}
else
{
ans = max({ans, max1[a], max1[b]});
}
printf("%lld\n", ans - (dist[a] + dist[b] - 2 * dist[p]) + v[a] + v[b]);
// printf("%lld %lld %lld %lld %lld\n", sum[p], max1[a], max1[b], res, ans);
}
return 0;
}