P11324 【MX-S7-T2】「SMOI-R2」Speaker

blind5883發表於2024-11-24

P11324 【MX-S7-T2】「SMOI-R2」Speaker - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)

本題本質上是一道較為複雜的分類討論題,最核心的部分是用樹上倍增求鏈上的最大值。細節請直接參考下方程式碼,這裡不再贅述。

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>

using namespace std;

typedef long long LL;

// N: capacity for vertices (1-based; index 0 is used as the "no parent" slot).
// M: capacity for directed edges (every undirected edge is stored twice).
const int N = 200010, M = N * 2;

// K: number of binary-lifting levels (levels 0 .. K-1).
// This used to be `log2(N) + 2`, which truncates to 19. std::log2 is not a
// constant expression before C++26, so using it as a file-scope array bound
// only compiles where the compiler folds the builtin. The literal is portable.
const int K = 19;
static_assert((1LL << K) > N, "K must give enough binary-lifting levels for N vertices");

// n: number of vertices, m: number of queries (both read in main).
int n, m;

// Forward-star adjacency list: h[a] is the first edge slot of a (-1 if none),
// e[i] the target vertex, ne[i] the next slot in the same list, w[i] the edge
// weight, idx the next unused slot.
int h[N], e[M], ne[M], w[M], idx;

// v[i]: value attached to vertex i (read from input).
LL v[N];

// Filled by dfs2: depth[u] (root has depth 1), fa[u][j] = 2^j-th ancestor
// (0 when it does not exist), maxv[u][j] = maximum of maxv[.][0] over the
// 2^j vertices starting at u and walking upwards.
LL depth[N], fa[N][K], maxv[N][K];

// Filled by dfs1: dist[u] = sum of edge weights from the root to u.
// max1/max2/max3[u] = the three largest candidates at u, where the candidates
// are v[u] itself and (max1[child] - 2 * edge weight) for each child.
LL dist[N], max1[N], max2[N], max3[N];

// mak1/mak2/mak3[u]: the child that produced max1/max2/max3[u]
// (0 when the slot holds v[u] or was never filled).
LL mak1[N], mak2[N], mak3[N];

// sum[u]: value propagated from above u by dfs2.
// res, kk: scratch outputs written by lca() and read by main().
LL sum[N], res, kk;

// Inserts the directed edge a -> b with weight c at the front of a's
// adjacency list, consuming the next free edge slot.
void add(int a, int b, int c)
{
    const int slot = idx++;
    e[slot] = b;
    w[slot] = c;
    ne[slot] = h[a];
    h[a] = slot;
}

void dfs1(int u, int fas)
{
    max1[u] = v[u];
    
    for (int i = h[u]; i != -1; i = ne[i])
    {
        int j = e[i];
        if (fas == j) continue;
        dist[j] = dist[u] + w[i];
        dfs1(j, u);
        // cout << u << ' ' << max1[u] << endl;
        // cout << j << ' ' << max1[j] << endl; 
        if (max1[j] - 2 * w[i] >= max1[u])
        {
            max3[u] = max2[u], mak3[u] = mak2[u];
            max2[u] = max1[u]; mak2[u] = mak1[u];
            max1[u] = max1[j] - 2 * w[i], mak1[u] = j;
            // cout << max1[u] << endl;
        }
        else if (max1[j] - 2 * w[i] >= max2[u])
        {
            max3[u] = max2[u], mak3[u] = mak2[u];
            max2[u] = max1[j] - 2 * w[i], mak2[u] = j;
        }
        else if (max1[j] - 2 * w[i] >= max3[u]) 
        {
            max3[u] = max1[j] - 2 * w[i];
            mak3[u] = j;
        }
        // cout << 'a';
        
    }
}

void dfs2(int u, int fas)
{
    depth[u] = depth[fas] + 1;
    fa[u][0] = fas;
    if (mak1[fas] != u) maxv[u][0] = max(v[u], max1[fas]);
    else maxv[u][0] = max(v[u], max2[fas]);
    
    for (int j = 1; j < K; j ++ )
    {
        fa[u][j] = fa[fa[u][j - 1]][j - 1];
        maxv[u][j] = max(maxv[u][j - 1], maxv[fa[u][j - 1]][j - 1]);
    }

    for (int i = h[u]; i != -1; i = ne[i])
    {
        int j = e[i];
        if (j == fas) continue;
        if (mak1[u] != j) sum[j] = max(sum[u], max1[u]) - 2 * w[i];
        else sum[j] = max(sum[u], max2[u]) - 2 * w[i];
        dfs2(j, u);
    }
}

int lca(int a, int b)
{
    if (depth[a] < depth[b]) swap(a, b);

    for (int j = K - 1; j >= 0; j -- )
        if (depth[fa[a][j]] >= depth[b])
        {
            if (depth[fa[a][j]] == depth[b]) kk = a;
            res = max(res, maxv[a][j]);
            a = fa[a][j];
            
        }
        
    if (a == b) 
    {
        
        res = max(res, v[a]);
        return a;
    }
    
    for (int j = K - 1; j >= 0; j -- )
        if (fa[a][j] != fa[b][j])
        {
            res = max({res, maxv[a][j], maxv[b][j]});
            a = fa[a][j];
            b = fa[b][j];
        }
    kk = a;
    int u = fa[a][0];
    if (a != mak1[u] && b != mak1[u]) res = max(res, max1[u]);
    else if (a != mak2[u] && b != mak2[u]) res = max(res, max2[u]);
    else res = max(res, max3[u]);
    
    return u;
}

int main()
{
//  	freopen("speaker4.in", "r", stdin);
//  	freopen("speaker4.out", "w", stdout);
    cin >> n >> m;
    for (int i = 1; i <= n; i ++ ) scanf("%lld", &v[i]);
    
    memset(h, -1, sizeof h);
    for (int i = 1; i < n; i ++ )
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        add(a, b, c);
        add(b, a, c);
        // cout << a << ' ' << b << ' ' <<  c << endl;
    }
    
    dfs1(1, 0);
    dfs2(1, 0);
    
    while (m -- )
    {
        int a, b;
        res = 0;
        scanf("%d%d", &a, &b);
        int p = lca(a, b);
        // printf("%d %d %d\n", a, b, p);     
        LL ans = max({sum[p], res});
        if (p == a) 
        {
            if (mak1[a] != kk) ans = max({ans, max1[b], max1[a]});
            else ans = max({ans, max1[b], max2[a]});
        }
        else if (p == b)
        {
            if (mak1[b] != kk) ans = max({ans, max1[a], max1[b]});
            else ans = max({ans, max1[a], max2[b]});
        }
        else 
        {
            ans = max({ans, max1[a], max1[b]});
        }
        printf("%lld\n", ans - (dist[a] + dist[b] - 2 * dist[p]) + v[a] + v[b]);
        // printf("%lld %lld %lld %lld %lld\n", sum[p], max1[a], max1[b], res, ans);
        
    }
    
    return 0;
}