HDU 2586 倍增 / Tarjan LCA

neweryyy發表於2020-12-15
題意

傳送門 HDU 2586

題解

$LCA$ 模板題。

倍增 LCA

基於倍增的 $LCA$,時間複雜度為 $O\big((n+m)\log n\big)$。

#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
// Capacity of every per-vertex array (vertices are indexed from 1).
#define maxn 40005
// Number of rows in the binary-lifting ancestor table.
#define maxlg 16

// Test-case count, vertex count and query count, all read in main().
int T, N, Q;
// lg[i] is filled in main(): lg[0] = -1, and for i >= 1 it is floor(log2(i)).
int lg[maxn];
// fa[k][x] is the ancestor of x reached by 2^k parent steps (set in dfs()).
int fa[maxlg][maxn];
// dep[x] is the depth of x; the root has depth 0.
int dep[maxn];
// dis[x] is the sum of edge costs on the path from the root to x.
int dis[maxn];

// Linked adjacency lists: head[x] is the first edge slot of x, nxt[e] the
// following slot (0 terminates the list), to[e]/cost[e] the edge endpoint and
// weight. tot is the index of the last slot handed out by add().
int head[maxn << 1];
int nxt[maxn << 1];
int to[maxn << 1];
int cost[maxn << 1];
int tot;

// Appends the directed edge x -> y with weight z to the front of x's
// adjacency list, using the next free slot (slots are numbered from 1).
void add(int x, int y, int z)
{
    const int slot = ++tot;
    to[slot] = y;
    cost[slot] = z;
    nxt[slot] = head[x]; // old first edge becomes the successor
    head[x] = slot;
}

void dfs(int x, int f, int d, int w)
{
    fa[0][x] = f, dep[x] = d, dis[x] = w;
    for (int k = 1; k <= lg[d]; ++k)
        fa[k][x] = fa[k - 1][fa[k - 1][x]];
    for (int i = head[x]; i; i = nxt[i])
    {
        int y = to[i], z = cost[i];
        if (y != f)
            dfs(y, x, d + 1, w + z);
    }
}

// Returns the lowest common ancestor of x and y using the tables built by dfs().
int lca(int x, int y)
{
    // Make x the deeper (or equally deep) vertex.
    if (dep[y] > dep[x])
        swap(x, y);

    // Lift x by the largest power of two not exceeding the remaining depth gap
    // until both vertices sit on the same level.
    for (int gap = dep[x] - dep[y]; gap > 0; gap = dep[x] - dep[y])
        x = fa[lg[gap]][x];

    if (x == y)
        return x;

    // Jump both vertices upward from the largest level down to 0, taking a
    // jump only when it keeps them on different vertices.
    int level = lg[dep[x]];
    while (level >= 0)
    {
        if (fa[level][x] != fa[level][y])
        {
            x = fa[level][x];
            y = fa[level][y];
        }
        --level;
    }
    // x and y are now distinct children of the answer.
    return fa[0][x];
}

// Reads T test cases. Each test case gives N and Q, then N - 1 weighted tree
// edges, then Q vertex pairs; for every pair the distance between the two
// vertices is printed on its own line.
//
// Every scanf() result is checked: on truncated or malformed input the program
// stops instead of continuing with stale or uninitialized values.
int main()
{
    if (scanf("%d", &T) != 1)
        return 0;

    // lg[i] = floor(log2(i)) for i >= 1: the value grows by one exactly when i
    // reaches the next power of two. lg[0] = -1 makes depth-0 loops run zero times.
    lg[0] = -1;
    for (int i = 1; i < maxn; ++i)
        lg[i] = lg[i - 1] + (1 << (lg[i - 1] + 1) == i);

    while (T--)
    {
        if (scanf("%d%d", &N, &Q) != 2)
            return 0;

        // Drop the previous test case's adjacency lists.
        tot = 0;
        memset(head, 0, sizeof(head));

        for (int i = 1; i < N; ++i)
        {
            int x, y, z;
            if (scanf("%d%d%d", &x, &y, &z) != 3)
                return 0;
            // Store the undirected edge as two directed ones.
            add(x, y, z);
            add(y, x, z);
        }

        // Root the tree at vertex 1 (parent 0, depth 0, distance 0).
        dfs(1, 0, 0, 0);

        while (Q--)
        {
            int x, y;
            if (scanf("%d%d", &x, &y) != 2)
                return 0;
            // Path length = dis[x] + dis[y] minus twice the shared root-to-LCA part.
            printf("%d\n", dis[x] + dis[y] - 2 * dis[lca(x, y)]);
        }
    }
    return 0;
}
Tarjan LCA

$LCA$ 的 $Tarjan$ 演算法是一個離線演算法,基本思想是:當節點 $x$ 遞迴時,若節點 $y$ 已經回溯,那麼從 $y$ 向上一直走到一個已開始遞迴但尚未回溯的節點,即 $LCA(x,y)$。$y$ 回溯時使用並查集維護其指向的 $LCA$。需要特殊處理 $x=y$ 的情況,因為節點不可能同時處於遞迴與結束回溯兩種狀態。時間複雜度為 $O(n+m)$。

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
// Capacity of every per-vertex array (vertices are indexed from 1).
#define maxn 40005
// Capacity of the answer array (query ids are indexed from 1).
#define maxm 205

// Test-case count, vertex count and query count, all read in main().
int T, N, Q;

// Linked adjacency lists: head[x] is the first edge slot of x, nxt[e] the
// following slot (0 terminates the list), to[e]/cost[e] the edge endpoint and
// weight. tot is the index of the last slot handed out by add().
int head[maxn << 1];
int nxt[maxn << 1];
int to[maxn << 1];
int cost[maxn << 1];
int tot;

// par[x]: disjoint-set parent link used by find().
int par[maxn];
// fg[x]: visit state set by tarjan() - 0 untouched, 1 being visited, 2 finished.
int fg[maxn];
// dis[x]: sum of edge costs on the path from the root to x.
int dis[maxn];
// res[id]: answer of query number id.
int res[maxm];

// Queries attached to each vertex: q_x[x][k] is the other endpoint and
// q_id[x][k] the query number of the k-th query stored at x.
vector<int> q_x[maxn];
vector<int> q_id[maxn];

// Appends the directed edge x -> y with weight z to the front of x's
// adjacency list, using the next free slot (slots are numbered from 1).
void add(int x, int y, int z)
{
    const int slot = ++tot;
    to[slot] = y;
    cost[slot] = z;
    nxt[slot] = head[x]; // old first edge becomes the successor
    head[x] = slot;
}

// Returns the representative of x's disjoint set and re-points every vertex
// on the walked path directly at that representative (path compression).
int find(int x)
{
    // First pass: locate the representative.
    int root = x;
    while (par[root] != root)
        root = par[root];

    // Second pass: re-link each vertex on the path straight to the representative.
    while (par[x] != root)
    {
        const int up = par[x];
        par[x] = root;
        x = up;
    }
    return root;
}

// Attaches query number i to vertex x, remembering y as the other endpoint.
void add_q(int x, int y, int i)
{
    q_x[x].push_back(y);
    q_id[x].push_back(i);
}

void tarjan(int x, int w)
{
    fg[x] = 1, dis[x] = w;
    for (int i = head[x]; i; i = nxt[i])
    {
        int y = to[i], z = cost[i];
        if (!fg[y])
        {
            tarjan(y, w + z);
            par[y] = x;
        }
    }
    for (int i = 0; i < (int)q_x[x].size(); ++i)
    {
        int y = q_x[x][i], id = q_id[x][i];
        if (fg[y] == 2)
        {
            int lca = find(y);
            res[id] = dis[x] + dis[y] - 2 * dis[lca];
        }
    }
    fg[x] = 2;
}

// Reads T test cases. Each test case gives N and Q, then N - 1 weighted tree
// edges, then Q vertex pairs. All queries are collected first, answered in a
// single tarjan() traversal from vertex 1, and printed in input order.
//
// Every scanf() result is checked: on truncated or malformed input the program
// stops instead of continuing with stale or uninitialized values.
// Answers are stored in res[1..Q], so Q must stay below maxm.
int main()
{
    if (scanf("%d", &T) != 1)
        return 0;

    while (T--)
    {
        if (scanf("%d%d", &N, &Q) != 2)
            return 0;

        // Reset the per-vertex state left over from the previous test case.
        tot = 0;
        for (int i = 1; i <= N; ++i)
        {
            head[i] = 0;
            fg[i] = 0;
            par[i] = i; // every vertex starts as its own set
            q_x[i].clear();
            q_id[i].clear();
        }

        for (int i = 1; i < N; ++i)
        {
            int x, y, z;
            if (scanf("%d%d%d", &x, &y, &z) != 3)
                return 0;
            // Store the undirected edge as two directed ones.
            add(x, y, z), add(y, x, z);
        }

        for (int i = 1; i <= Q; ++i)
        {
            int x, y;
            if (scanf("%d%d", &x, &y) != 2)
                return 0;
            if (x == y)
                res[i] = 0; // tarjan() never answers a query whose endpoints coincide
            else
                // Register the query at both endpoints; whichever is visited
                // second will find the other one finished and answer it.
                add_q(x, y, i), add_q(y, x, i);
        }

        tarjan(1, 0);

        for (int i = 1; i <= Q; ++i)
            printf("%d\n", res[i]);
    }
    return 0;
}