LCA + 樹上倍增

gebeng發表於2024-04-06

LCA + 樹上倍增

一、例題引入

題目:

2846. 邊權重均等查詢

現有一棵由 n 個節點組成的無向樹,節點按從 0n - 1 編號。給你一個整數 n 和一個長度為 n - 1 的二維整數陣列 edges ,其中 edges[i] = [ui, vi, wi] 表示樹中存在一條位於節點 ui 和節點 vi 之間、權重為 wi 的邊。

另給你一個長度為 m 的二維整數陣列 queries ,其中 queries[i] = [ai, bi] 。對於每條查詢,請你找出使從 aibi 路徑上每條邊的權重相等所需的 最小操作次數 。在一次操作中,你可以選擇樹上的任意一條邊,並將其權重更改為任意值。

注意:

  • 查詢之間 相互獨立 的,這意味著每條新的查詢時,樹都會回到 初始狀態
  • aibi的路徑是一個由 不同 節點組成的序列,從節點 ai 開始,到節點 bi 結束,且序列中相鄰的兩個節點在樹中共享一條邊。

返回一個長度為 m 的陣列 answer ,其中 answer[i] 是第 i 條查詢的答案。

示例:

query[i] = [2,6],將2-3這條邊改成2,所以ans[i] = 1

img

思路:

  • ①求2到6的距離 d = 4;
  • ②求2到6邊權出現次數最多的次數 cnt_max = 3;
  • ③答案即為:d - cnt_max = 1

二、對症下藥

①怎麼快速求出一棵樹上任意兩個點的距離呢?

d(a-b) = d(a-lca) + d(b-lca) = (d(a-root) - d(lca-root)) + (d(b-root) - d(lca-root)) = d(a) + d(b) - 2 x d(lca)

只要求出最近公共祖先lca後,就可以根據如上公式求出任意兩點的距離。

②怎麼求公共祖先呢?

1.預處理pa陣列

pa[x][0] = y代表 x 的父節點是y.

pa[x][1] = y 代表 x 的父節點的父節點是y.

pa[x][2] = y代表 x 的爺節點的爺節點是y.

依次類推..........................................................

pa[x][i + 1] = pa[pa[x][i]][i]

// 設 m 為最大編號的二進位制位數,pa陣列初始化為-1
for (int i = 0; i < m - 1; i++) {
    for (int x = 0; x < n; x++) {
        int p = pa[x][i];
        if (p != -1)  pa[x][i + 1] = pa[p][i];
    }
}

2.二進位制倍增

xy的最近公共祖先為lca,根節點為root

  • 首先,使得xy同一層

    • 如果在同一層時x = y,那麼lca = x = y
  • xy按照i從大往小跳 \(2^i\)得到 fx,fy(類比數的二進位制表示)

    • 如果 fx = fy,就說明跳得太遠了,(超過了lca或者就是lca)下一次就跳得一些
    • 如果fx != fy,就說明在lca之下,那麼更新x = fx,y = fy
  • 最後,得到的節點一定是lca的兒子節點

    • lca = pa[x][0]
if (depth[x] > depth[y]) swap(x, y);
// 讓 y 和 x 在同一深度
for (int k = depth[y] - depth[x]; k; k &= k - 1) {
    int i = __builtin_ctz(k);
    int p = pa[y][i];
    y = p;
}
if (y != x) {
    // x 和 y 同時上跳 2^i 步
    for (int i = m - 1; i >= 0; i--) {
        int fx = pa[x][i], fy = pa[y][i];
        if (fx != fy) {
            x = fx;
            y = fy; 
        }
    }
    x = pa[x][0];
}
lca = x;

③怎麼求邊權出現次數最多的那條邊的次數呢?

1.定義cnt[x][i][w]陣列

cnt[x][0][w] = 1 代表 xx父節點之間的路徑的權值為w的邊個數為1.

cnt[x][1][w] = 2 代表 xx爺節點之間的路徑的權值為w的邊個數為2.

cnt[x][i][w] = cnt代表 xx\(2^i\) 後的節點的路徑的權值為w的邊個數為cnt.

只需在求LCA的過程中維護與更新cnt即可!

三、程式碼展示

class Solution {
public:
    vector<int> minOperationsQueries(int n, vector<vector<int>> &edges, vector<vector<int>> &queries) {
        vector<vector<pair<int, int>>> g(n);
        for (auto &e: edges) {
            int x = e[0], y = e[1], w = e[2] - 1;
            g[x].emplace_back(y, w);
            g[y].emplace_back(x, w);
        }
        int m = __lg(n) + 1; // n 的二進位制長度
        vector<vector<int>> pa(n, vector<int>(m, -1));
        vector<vector<array<int, 26>>> cnt(n, vector<array<int, 26>>(m));
        vector<int> depth(n);
        function<void(int, int)> dfs = [&](int x, int fa) {
            pa[x][0] = fa;
            for (auto [y, w]: g[x]) {
                if (y != fa) {
                    cnt[y][0][w] = 1;
                    depth[y] = depth[x] + 1;
                    dfs(y, x);
                }
            }
        };
        dfs(0, -1);
        for (int i = 0; i < m - 1; i++) {
            for (int x = 0; x < n; x++) {
                int p = pa[x][i];
                if (p != -1) {
                    pa[x][i + 1] = pa[p][i];
                    for (int j = 0; j < 26; ++j) {
                        cnt[x][i + 1][j] = cnt[x][i][j] + cnt[p][i][j];
                    }
                }
            }
        }
        vector<int> ans;
        for (auto &q: queries) {
            int x = q[0], y = q[1];
            int path_len = depth[x] + depth[y]; // 最後減去 depth[lca] * 2
            int cw[26]{};
            if (depth[x] > depth[y]) {
                swap(x, y);
            }
            // 讓 y 和 x 在同一深度
            for (int k = depth[y] - depth[x]; k; k &= k - 1) {
                int i = __builtin_ctz(k);
                int p = pa[y][i];
                for (int j = 0; j < 26; ++j) {
                    cw[j] += cnt[y][i][j];
                }
                y = p;
            }
            if (y != x) {
                for (int i = m - 1; i >= 0; i--) {
                    int fx = pa[x][i], fy = pa[y][i];
                    if (fx != fy) {
                        for (int j = 0; j < 26; j++) {
                            cw[j] += cnt[x][i][j] + cnt[y][i][j];
                        }
                        x = fx;y = fy; // x 和 y 同時上跳 2^i 步
                    }
                }
                for (int j = 0; j < 26; j++) {
                    cw[j] += cnt[x][0][j] + cnt[y][0][j];
                }
                x = pa[x][0];
            }
            int lca = x;
            path_len -= depth[lca] * 2;
            ans.push_back(path_len - *max_element(cw, cw + 26));
        }
        return ans;
    }
};

四、實戰演練

給定一棵包含 n個節點的有根無向樹,節點編號互不相同,但不一定是 1∼n。

有 m個詢問,每個詢問給出了一對節點的編號 x 和 y,詢問 x與 y 的祖孫關係。

輸入格式

輸入第一行包括一個整數 表示節點個數;

接下來 n行每行一對整數 a 和 b,表示 a 和 b 之間有一條無向邊。如果 b是 −1−1,那麼 a 就是樹的根;

第 n+2 行是一個整數 m表示詢問個數;

接下來 m 行,每行兩個不同的正整數 x和 y,表示一個詢問。

輸出格式

對於每一個詢問,若 x是 y的祖先則輸出 1,若 y是 x的祖先則輸出 2,否則輸出 0。

程式碼撰寫

#include<bits/stdc++.h>
using namespace std;
const int N = 40010, M = 2 * N;
int h[N], e[M], ne[M], idx;
int depth[N], pa[N][20],root;
void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}

//預處理每個結點的深度,以及結點的父結點的編號
void dfs(int u, int fa)
{
    pa[u][0] = fa;
    for(int i = h[u]; ~i; i = ne[i])
    {
        int v = e[i];
        if(v != fa){
            depth[v] = depth[u] + 1;
            dfs(v,u);
        }
    }
}
int get_lca(int x,int y){
    if(depth[x] > depth[y]) swap(x,y);
    for(int k = depth[y] - depth[x];k;k &= k - 1){
        int i = __builtin_ctz(k);
        y = pa[y][i];
    }
    if(x == y) return y;
    for(int i = 15;i >= 0;--i){
        int fx = pa[x][i],fy = pa[y][i];
        if(fx != fy){
            x = fx;y = fy;
        }
    }
    return pa[x][0];
}
int main()
{
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    memset(h, -1, sizeof(h));memset(pa,-1,sizeof(pa));
    int t;cin >> t;
    while(t--){
        int a,b;cin >> a >> b;
        if(b == -1) root = a;
        else {add(a,b);add(b,a);}
    }
    dfs(root,-1);
    for(int i = 0;i < 15;++i){
        for(int u = 0;u < N;++u){
            int p = pa[u][i];
            if(p != -1) pa[u][i + 1] = pa[p][i];
        }
    }
    cin >> t;
    while(t--){
        int a,b;cin >> a >> b;
        int lca = get_lca(a,b);
        if     (lca == a) cout << '1' << '\n';
        else if(lca == b) cout << '2' << '\n';
        else              cout << '0' << '\n';
    }
    return 0;
}

相關文章