樹鏈剖分
樹鏈剖分的基本思想
透過將樹分割成鏈的形式,從而把樹形變為線性結構,減少處理難度。
樹鏈剖分(樹剖/鏈剖)有多種形式,如 重鏈剖分,長鏈剖分 和用於 Link/cut Tree 的剖分(有時被稱作「實鏈剖分」),大多數情況下(沒有特別說明時),「樹鏈剖分」都指「重鏈剖分」。
樹鏈有如下幾個特徵:
- 一棵樹上的任意一條鏈的長度不超過 \(\log_2 n\),且一條鏈上的各個節點深度互不相同(相對於根節點而言)。
- 透過特殊的遍歷方式,樹鏈剖分可以保證一條鏈上的 DFS 序連續,從而更加方便地使用線段樹或樹狀陣列維護樹上的區間資訊。
重鏈剖分
重鏈剖分的基本定義
重鏈剖分,顧名思義,是一種透過子節點大小進行剖分的形式,我們給出以下定義:
-
重子節點:一個非葉子節點中子樹最大的子節點,如果存在多個則取其一。
-
輕子節點:除了重子節點以外的所有子節點。
-
重邊:連線任意兩個重兒子的邊。
-
輕邊:除了重邊的其他所有樹邊。
-
重鏈:若干條重邊首尾相連而成的一條鏈。
我們將落單的葉子節點本身看做重子節點,不難發現,整棵樹就被剖分成了一條條重鏈。
需要注意的是,每一條重鏈都以輕子節點為起點。
實現
樹鏈剖分的處理透過兩次 DFS 遍歷完成。
對於第一次 DFS,我們求出如下數值:
-
任意節點到達根的距離(即其深度):
depth[]
-
任意節點的父親節點(根節點預設為 \(0\)):
fa[]
-
任意節點子樹的大小(包括其本身):
sz[]
-
任意節點的重子節點(沒有則為 \(0\)):
hson[]
void dfs1(int ver, int pre, int deep)
{
depth[ver] = deep, fa[ver] = pre, sz[ver] = 1;
int maxn = -1;
for (int i = h[ver]; ~i; i = ne[i])
{
int j = e[i];
if (j == pre) continue;
dfs1(j, ver, deep + 1);
sz[ver] += sz[j];
if (maxn == -1 || maxn < sz[j]) maxn = sz[j], hson[ver] = j;
// 更新重子節點
}
}
對於第二次 DFS,我們求出如下數值:
-
遍歷時的各個節點的 dfs 序:
dfn[]
-
每個節點所屬重鏈的最頂端節點:
top[]
-
dfs 序對應的節點編號:
id[]
,有 \(id(dfn(x)) = x\) -
每個節點在 dfs 序上的對應權值:
val[]
void dfs2(int ver, int topf)
{
dfn[ver] = ++ timestamp, val[timestamp] = a[ver], top[ver] = topf;
if (!hson[ver]) return;
dfs2(hson[ver], topf); // 先遍歷重子節點
for (int i = h[ver]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[ver] || j == hson[ver]) continue;
dfs2(j, j); // 再遍歷輕子節點
}
}
之所以要先遍歷重子節點,是因為我們要保證重鏈上的 dfs 序連續,這樣才可以進行區間操作,按照 \(dfn\) 排序後的序列即為剖分後的鏈。
重鏈剖分的性質
-
樹上每個節點都屬於且僅屬於一條重鏈。
-
所有的重鏈將整棵樹 完全剖分。
-
當我們向下經過一條 輕邊 時,所在子樹的大小至少會除以二,保證了複雜度的正確性。
常見應用
維護路徑權值和
選取左右端點所在樹中深度更大的節點,維護它到所在重鏈頂端的區間資訊,之後不斷上跳,知道它和另一端點在同一鏈上,維護兩點之間的資訊。使用線段樹或者樹狀陣列等資料結構,即可在 \(O(\log^2 n)\) 的時間內單次維護查詢。
void modify_range(int x, int y, int k)
{
while (top[x] != top[y])
{
if (depth[top[x]] < depth[top[y]]) swap(x, y);
SGT.modify(1, dfn[top[x]], dfn[x], k);
x = fa[top[x]];
}
if (depth[x] > depth[y]) swap(x, y);
SGT.modify(1, dfn[x], dfn[y], k);
}
int query_range(int x, int y)
{
int res = 0;
while (top[x] != top[y])
{
if (depth[top[x]] < depth[top[y]]) swap(x, y);
res = (res + SGT.query(1, dfn[top[x]], dfn[x])) % mod;
x = fa[top[x]];
}
if (depth[x] > depth[y]) swap(x, y);
res = (res + SGT.query(1, dfn[x], dfn[y])) % mod;
return res;
}
維護子樹資訊
思路相似,但更加簡單,經過 dfn 重新劃分後,一顆子樹的 dfn 序列一定在 \([dfn[x], dfn[x] +sz[x] - 1]\) 之間,單次維護即可,時間複雜度 \(O(\log n)\)。
void modify_subtree(int x, int k)
{
SGT.modify(1, dfn[x], dfn[x] + sz[x] - 1, k);
}
int query_subtree(int x)
{
return SGT.query(1, dfn[x], dfn[x] + sz[x] - 1);
}
求 LCA
與倍增求法相似,但常數更小。
每次選取重鏈頂端節點深度更大的節點上跳,知道兩者在同一重鏈上,此時深度較小者為兩節點 LCA。
int lca(int a, int b)
{
while (top[a] != top[b])
{
if (depth[top[a]] > depth[top[b]]) a = fa[top[a]];
else b = fa[top[b]];
}
return depth[a] < depth[b] ? a : b;
}
例題 & Code
P3384 【模板】重鏈剖分/樹鏈剖分
// Problem: P3384 【模板】重鏈剖分/樹鏈剖分
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3384
// Memory Limit: 128 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
using namespace std;
// #define int long long
#define DEBUG
#define lc u << 1
#define rc u << 1 | 1
#define File(a) freopen(a".in", "r", stdin); freopen(a".out", "w", stdout)
typedef long long LL;
typedef pair<int, int> PII;
const int N = 100010, M = N << 1;
const int INF = 0x3f3f3f3f;
int n, m, root, mod;
int h[N], e[M], ne[M], idx;
int a[N], val[N];
int depth[N], fa[N], sz[N], hson[N];
int dfn[N], timestamp;
int top[N], id[N];
struct Tree
{
struct Node
{
int l, r, sum, tag;
inline int len() {return r - l + 1; }
} tr[N << 2];
void pushup(int u)
{
tr[u].sum = (tr[lc].sum + tr[rc].sum) % mod;
}
void build(int u, int l, int r)
{
tr[u].l = l, tr[u].r = r;
if (l == r) return tr[u].sum = val[l], void(0);
int mid = l + r >> 1;
build(lc, l, mid), build(rc, mid + 1, r);
pushup(u);
}
void pushdown(int u)
{
if (!tr[u].tag) return;
tr[lc].sum = (tr[lc].sum + tr[u].tag * tr[lc].len()) % mod;
tr[rc].sum = (tr[rc].sum + tr[u].tag * tr[rc].len()) % mod;
tr[lc].tag += tr[u].tag, tr[rc].tag += tr[u].tag;
tr[u].tag = 0;
}
void modify(int u, int l, int r, int k)
{
if (l <= tr[u].l && tr[u].r <= r)
{
tr[u].sum = (tr[u].sum + tr[u].len() * k) % mod;
tr[u].tag += k;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(lc, l, r, k);
if (r > mid) modify(rc, l, r, k);
pushup(u);
}
int query(int u, int l, int r)
{
if (l <= tr[u].l && tr[u].r <= r)
return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
int res = 0;
if (l <= mid) res = (res + query(lc, l, r)) % mod;
if (r > mid) res = (res + query(rc, l, r)) % mod;
return res;
}
} SGT;
inline void add(int a, int b)
{
e[++ idx] = b, ne[idx] = h[a], h[a] = idx;
}
void dfs1(int ver, int pre, int deep)
{
depth[ver] = deep, fa[ver] = pre, sz[ver] = 1;
int maxn = -1;
for (int i = h[ver]; ~i; i = ne[i])
{
int j = e[i];
if (j == pre) continue;
dfs1(j, ver, deep + 1);
sz[ver] += sz[j];
if (maxn == -1 || maxn < sz[j]) maxn = sz[j], hson[ver] = j;
}
}
void dfs2(int ver, int topf)
{
dfn[ver] = ++ timestamp, val[timestamp] = a[ver], top[ver] = topf;
if (!hson[ver]) return;
dfs2(hson[ver], topf);
for (int i = h[ver]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[ver] || j == hson[ver]) continue;
dfs2(j, j);
}
}
void modify_range(int x, int y, int k)
{
while (top[x] != top[y])
{
if (depth[top[x]] < depth[top[y]]) swap(x, y);
SGT.modify(1, dfn[top[x]], dfn[x], k);
x = fa[top[x]];
}
if (depth[x] > depth[y]) swap(x, y);
SGT.modify(1, dfn[x], dfn[y], k);
}
int query_range(int x, int y)
{
int res = 0;
while (top[x] != top[y])
{
if (depth[top[x]] < depth[top[y]]) swap(x, y);
res = (res + SGT.query(1, dfn[top[x]], dfn[x])) % mod;
x = fa[top[x]];
}
if (depth[x] > depth[y]) swap(x, y);
res = (res + SGT.query(1, dfn[x], dfn[y])) % mod;
return res;
}
void modify_subtree(int x, int k)
{
SGT.modify(1, dfn[x], dfn[x] + sz[x] - 1, k);
}
int query_subtree(int x)
{
return SGT.query(1, dfn[x], dfn[x] + sz[x] - 1);
}
signed main()
{
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
memset(h, -1, sizeof h);
cin >> n >> m >> root >> mod;
for (int i = 1; i <= n; i ++) cin >> a[i];
for (int i = 1; i < n; i ++)
{
int u, v; cin >> u >> v;
add(u, v), add(v, u);
}
dfs1(root, 0, 1);
dfs2(root, root);
SGT.build(1, 1, n);
while (m --)
{
int opt; cin >> opt;
if (opt == 1)
{
int x, y, z; cin >> x >> y >> z;
modify_range(x, y, z);
}
else if (opt == 2)
{
int x, y; cin >> x >> y;
cout << query_range(x, y) % mod << '\n';
}
else if (opt == 3)
{
int x, y; cin >> x >> y;
modify_subtree(x, y);
}
else
{
int x; cin >> x;
cout << query_subtree(x) % mod << '\n';
}
}
return 0;
}
Reference
樹鏈剖分 - OI Wiki
P3384 【模板】重鏈剖分/樹鏈剖分 題解