LCA + 樹上倍增
一、例題引入
題目:
2846. 邊權重均等查詢
現有一棵由 n
個節點組成的無向樹,節點按從 0
到 n - 1
編號。給你一個整數 n
和一個長度為 n - 1
的二維整數陣列 edges
,其中 edges[i] = [ui, vi, wi]
表示樹中存在一條位於節點 ui
和節點 vi
之間、權重為 wi
的邊。
另給你一個長度為 m
的二維整數陣列 queries
,其中 queries[i] = [ai, bi]
。對於每條查詢,請你找出使從 ai
到 bi
路徑上每條邊的權重相等所需的 最小操作次數 。在一次操作中,你可以選擇樹上的任意一條邊,並將其權重更改為任意值。
注意:
- 查詢之間 相互獨立 的,這意味著每條新的查詢時,樹都會回到 初始狀態 。
- 從
ai
到bi
的路徑是一個由 不同 節點組成的序列,從節點ai
開始,到節點bi
結束,且序列中相鄰的兩個節點在樹中共享一條邊。
返回一個長度為 m
的陣列 answer
,其中 answer[i]
是第 i
條查詢的答案。
示例:
query[i] = [2,6]
,將2-3
這條邊改成2
,所以ans[i] = 1
思路:
- ①求2到6的距離
d = 4
; - ②求2到6邊權出現次數最多的次數
cnt_max = 3
; - ③答案即為:
d - cnt_max = 1
二、對症下藥
①怎麼快速求出一棵樹上任意兩個點的距離呢?
d(a-b) = d(a-lca) + d(b-lca) = (d(a-root) - d(lca-root)) + (d(b-root) - d(lca-root)) = d(a) + d(b) - 2 x d(lca)
只要求出最近公共祖先lca
後,就可以根據如上公式求出任意兩點的距離。
②怎麼求公共祖先呢?
1.預處理pa陣列
pa[x][0] = y
代表 x 的父節點是y.
pa[x][1] = y
代表 x 的父節點的父節點是y.
pa[x][2] = y
代表 x 的爺節點的爺節點是y.
依次類推..........................................................
pa[x][i + 1] = pa[pa[x][i]][i]
// 設 m 為最大編號的二進位制位數,pa陣列初始化為-1
for (int i = 0; i < m - 1; i++) {
for (int x = 0; x < n; x++) {
int p = pa[x][i];
if (p != -1) pa[x][i + 1] = pa[p][i];
}
}
2.二進位制倍增
設x
與y
的最近公共祖先為lca
,根節點為root
-
首先,使得
x
與y
在同一層
- 如果在同一層時
x = y
,那麼lca = x = y
- 如果在同一層時
-
將
x
與y
按照i
從大往小跳 \(2^i\)得到fx
,fy
(類比數的二進位制表示)- 如果
fx = fy
,就說明跳得太遠
了,(超過了lca或者就是lca
)下一次就跳得近
一些 - 如果
fx != fy
,就說明在lca
之下,那麼更新x = fx
,y = fy
- 如果
-
最後,得到的節點一定是
lca
的兒子節點lca = pa[x][0]
if (depth[x] > depth[y]) swap(x, y);
// 讓 y 和 x 在同一深度
for (int k = depth[y] - depth[x]; k; k &= k - 1) {
int i = __builtin_ctz(k);
int p = pa[y][i];
y = p;
}
if (y != x) {
// x 和 y 同時上跳 2^i 步
for (int i = m - 1; i >= 0; i--) {
int fx = pa[x][i], fy = pa[y][i];
if (fx != fy) {
x = fx;
y = fy;
}
}
x = pa[x][0];
}
lca = x;
③怎麼求邊權出現次數最多的那條邊的次數呢?
1.定義cnt[x][i][w]
陣列
cnt[x][0][w] = 1
代表 x
與x
父節點之間的路徑的權值為w
的邊個數為1
.
cnt[x][1][w] = 2
代表 x
與x
爺節點之間的路徑的權值為w
的邊個數為2
.
cnt[x][i][w] = cnt
代表 x
與x
跳\(2^i\) 後的節點的路徑的權值為w
的邊個數為cnt
.
只需在求LCA
的過程中維護與更新cnt
即可!
三、程式碼展示
class Solution {
public:
vector<int> minOperationsQueries(int n, vector<vector<int>> &edges, vector<vector<int>> &queries) {
vector<vector<pair<int, int>>> g(n);
for (auto &e: edges) {
int x = e[0], y = e[1], w = e[2] - 1;
g[x].emplace_back(y, w);
g[y].emplace_back(x, w);
}
int m = __lg(n) + 1; // n 的二進位制長度
vector<vector<int>> pa(n, vector<int>(m, -1));
vector<vector<array<int, 26>>> cnt(n, vector<array<int, 26>>(m));
vector<int> depth(n);
function<void(int, int)> dfs = [&](int x, int fa) {
pa[x][0] = fa;
for (auto [y, w]: g[x]) {
if (y != fa) {
cnt[y][0][w] = 1;
depth[y] = depth[x] + 1;
dfs(y, x);
}
}
};
dfs(0, -1);
for (int i = 0; i < m - 1; i++) {
for (int x = 0; x < n; x++) {
int p = pa[x][i];
if (p != -1) {
pa[x][i + 1] = pa[p][i];
for (int j = 0; j < 26; ++j) {
cnt[x][i + 1][j] = cnt[x][i][j] + cnt[p][i][j];
}
}
}
}
vector<int> ans;
for (auto &q: queries) {
int x = q[0], y = q[1];
int path_len = depth[x] + depth[y]; // 最後減去 depth[lca] * 2
int cw[26]{};
if (depth[x] > depth[y]) {
swap(x, y);
}
// 讓 y 和 x 在同一深度
for (int k = depth[y] - depth[x]; k; k &= k - 1) {
int i = __builtin_ctz(k);
int p = pa[y][i];
for (int j = 0; j < 26; ++j) {
cw[j] += cnt[y][i][j];
}
y = p;
}
if (y != x) {
for (int i = m - 1; i >= 0; i--) {
int fx = pa[x][i], fy = pa[y][i];
if (fx != fy) {
for (int j = 0; j < 26; j++) {
cw[j] += cnt[x][i][j] + cnt[y][i][j];
}
x = fx;y = fy; // x 和 y 同時上跳 2^i 步
}
}
for (int j = 0; j < 26; j++) {
cw[j] += cnt[x][0][j] + cnt[y][0][j];
}
x = pa[x][0];
}
int lca = x;
path_len -= depth[lca] * 2;
ans.push_back(path_len - *max_element(cw, cw + 26));
}
return ans;
}
};
四、實戰演練
給定一棵包含 n個節點的有根無向樹,節點編號互不相同,但不一定是 1∼n。
有 m個詢問,每個詢問給出了一對節點的編號 x 和 y,詢問 x與 y 的祖孫關係。
輸入格式
輸入第一行包括一個整數 表示節點個數;
接下來 n行每行一對整數 a 和 b,表示 a 和 b 之間有一條無向邊。如果 b是 −1−1,那麼 a 就是樹的根;
第 n+2 行是一個整數 m表示詢問個數;
接下來 m 行,每行兩個不同的正整數 x和 y,表示一個詢問。
輸出格式
對於每一個詢問,若 x是 y的祖先則輸出 1,若 y是 x的祖先則輸出 2,否則輸出 0。
程式碼撰寫
#include<bits/stdc++.h>
using namespace std;
const int N = 40010, M = 2 * N;
int h[N], e[M], ne[M], idx;
int depth[N], pa[N][20],root;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
//預處理每個結點的深度,以及結點的父結點的編號
void dfs(int u, int fa)
{
pa[u][0] = fa;
for(int i = h[u]; ~i; i = ne[i])
{
int v = e[i];
if(v != fa){
depth[v] = depth[u] + 1;
dfs(v,u);
}
}
}
int get_lca(int x,int y){
if(depth[x] > depth[y]) swap(x,y);
for(int k = depth[y] - depth[x];k;k &= k - 1){
int i = __builtin_ctz(k);
y = pa[y][i];
}
if(x == y) return y;
for(int i = 15;i >= 0;--i){
int fx = pa[x][i],fy = pa[y][i];
if(fx != fy){
x = fx;y = fy;
}
}
return pa[x][0];
}
int main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
memset(h, -1, sizeof(h));memset(pa,-1,sizeof(pa));
int t;cin >> t;
while(t--){
int a,b;cin >> a >> b;
if(b == -1) root = a;
else {add(a,b);add(b,a);}
}
dfs(root,-1);
for(int i = 0;i < 15;++i){
for(int u = 0;u < N;++u){
int p = pa[u][i];
if(p != -1) pa[u][i + 1] = pa[p][i];
}
}
cin >> t;
while(t--){
int a,b;cin >> a >> b;
int lca = get_lca(a,b);
if (lca == a) cout << '1' << '\n';
else if(lca == b) cout << '2' << '\n';
else cout << '0' << '\n';
}
return 0;
}