「學習筆記」樹鏈剖分

yi_fan0305發表於2023-09-21

樹鏈剖分用於將樹分割成若干條鏈的形式,以維護樹上路徑的資訊。
具體來說,將整棵樹剖分為若干條鏈,使它組合成線性結構,然後用其他的資料結構維護資訊。

樹鏈剖分有很多種形式,本文要講的是其中的輕重鏈剖分。

樹鏈剖分本質上就是把鏈從樹上砍下來,然後放到樹狀陣列或線段樹上來維護。

輕重鏈剖分

我們給出一些定義:

定義 重子節點 表示其子節點中 子樹最大 的子結點。如果有多個子樹最大的子結點,取其一。如果沒有子節點,就無重子節點。

定義 輕子節點 表示剩餘的所有子結點。

從這個結點到重子節點的邊為 重邊

到其他輕子節點的邊為 輕邊

若干條首尾銜接的重邊構成 重鏈

把落單的結點也當作重鏈,那麼整棵樹就被剖分成若干條重鏈。

如圖:

image

圖片來自 \(\texttt{OI-Wiki}\)

實現

要進行樹鏈剖分,得先有一棵樹,透過一個 dfs 來得到所有的子樹大小、重兒子以及每個結點的深度。

void dfs(int u, int fat) {
    siz[u] = 1;
    fa[u] = fat;
    dep[u] = dep[fat] + 1;
    for (int v : e[u]) {
        if (v == fat)   continue ;
        dfs(v, u);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) {
            son[u] = v;
        }
    }
}

得到這些資訊後,我們就要把樹分成好多的鏈,透過另一個 dfs 來確定好鏈頂,有時我們要用線段樹或樹狀陣列來維護這些鏈,所以可能還要維護線上段樹上的位置(dfs 序)等資訊。

void getpos(int u, int top) {
    dfn[u] = ++ tim; // 線上段樹上的位置 (dfs 序)
    pos[tim] = u; // 線段樹這個位置所代表的節點
    tp[u] = top; // 鏈頂
    if (!son[u])    return ;
    getpos(son[u], top); // 優先跑重兒子
    for (int v : e[u]) {
        if (v == fa[u] || v == son[u])  continue ;
        getpos(v, v); // 輕兒子再單獨分鏈
    }
}

到這裡,處理部分就完成了,接下來就是根據題目來對這些鏈進行處理了,可以使用線段樹、樹狀陣列等資料結構來維護鏈的資訊,也可以利用跳鏈頂的優秀複雜度求 LCA 等。

例題

P3384 【模板】重鏈剖分/樹鏈剖分 - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)

模板題。

要求我們對鏈上節點的權值進行修改、求和,以及對子樹內的所有節點進行修改、求和,我們使用線段樹來進行維護。

對於鏈上的操作,由於我們已經把鏈都擺在了線段樹上,所以只需要對線段樹進行區間操作即可。對於子樹的操作,根據 dfs 序的性質,一棵子樹在 dfs 序上的範圍是 \([dfn_{rt}, dfn_{rt} + siz_{rt} - 1]\),對這個區間進行區間操作即可。

void Modify(int x, int y, ll z) {
    while (tp[x] != tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) {
            swap(x, y);
        }
        modify(1, 1, n, dfn[tp[x]], dfn[x], z);
        x = fa[tp[x]];
    }
    if (dep[x] > dep[y]) {
        swap(x, y);
    }
    modify(1, 1, n, dfn[x], dfn[y], z);
    return ;
}

ll Query(int x, int y) {
    ll ans = 0;
    while (tp[x] != tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) {
            swap(x, y);
        }
        ans = (ans + query(1, 1, n, dfn[tp[x]], dfn[x])) % mod;
        x = fa[tp[x]];
    }
    if (dep[x] > dep[y]) {
        swap(x, y);
    }
    ans = (ans + query(1, 1, n, dfn[x], dfn[y])) % mod;
    return ans;
}

這兩段程式碼就是跳鏈的過程,可以這樣理解一下,如果兩個節點的鏈頂不相同,說明他們不在同一條鏈中,我們讓鏈頂深度大的節點向上跳(這樣可以防止跳過頭),在跳之前,先對這段鏈的資訊進行修改維護,也就是 Modify 中的 modify 函式和 Query 中的 query 函式,然後,跳到這個鏈頂的父親,離開這條鏈,以此繼續,直到這兩個節點鏈頂一樣時(即在同一條鏈上時),對這兩個節點之間的鏈進行操作,退出函式。

// The code was written by yifan, and yifan is neutral!!!

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define bug puts("NOIP rp ++!");
#define rep(i, a, b, c) for (int i = (a); i <= (b); i += (c))
#define per(i, a, b, c) for (int i = (a); i >= (b); i -= (c))
#define Mod(x) ((x) >= mod ? (x) %= mod : (x))

template<typename T>
inline T read() {
    T x = 0;
    bool fg = 0;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        fg |= (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    return fg ? ~x + 1 : x;
}

const int N = 1e5 + 5;

int n, m, rt, mod, tim;
int val[N], siz[N], son[N], fa[N];
int dfn[N], pos[N], tp[N], dep[N];
vector<int> e[N];

struct seg {
    int len;
    ll val, tag;
} t[N << 2];

void dfs(int u, int fat) {
    siz[u] = 1;
    fa[u] = fat;
    dep[u] = dep[fat] + 1;
    for (int v : e[u]) {
        if (v == fat)   continue ;
        dfs(v, u);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) {
            son[u] = v;
        }
    }
}

void getpos(int u, int top) {
    dfn[u] = ++ tim;
    pos[tim] = u;
    tp[u] = top;
    if (!son[u])    return ;
    getpos(son[u], top);
    for (int v : e[u]) {
        if (v == fa[u] || v == son[u])  continue ;
        getpos(v, v);
    }
}

#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)

void pushup(int u) {
    t[u].val = (t[ls].val + t[rs].val) % mod;
}

void pushdown(int u, int l, int r) {
    if (!t[u].tag)  return ;
    if (l == r) {
        t[u].tag = 0;
        return ;
    }
    t[ls].tag = (t[ls].tag + t[u].tag) % mod;
    t[ls].val = (t[ls].val + t[u].tag * t[ls].len % mod) % mod;
    t[rs].tag = (t[rs].tag + t[u].tag) % mod;
    t[rs].val = (t[rs].val + t[u].tag * t[rs].len % mod) % mod;
    t[u].tag = 0;
    return ;
}

void build(int u, int l, int r) {
    t[u].tag = 0;
    t[u].len = r - l + 1;
    if (l == r) {
        t[u].val = val[pos[l]];
        return ;
    }
    build(ls, l, mid);
    build(rs, mid + 1, r);
    pushup(u);
}

void modify(int u, int l, int r, int lr, int rr, ll v) {
    if (lr <= l && r <= rr) {
        t[u].tag = (t[u].tag + v) % mod;
        t[u].val = (t[u].val + (t[u].len * v) % mod) % mod;
        return ;
    }
    pushdown(u, l, r);
    if (lr <= mid) {
        modify(ls, l, mid, lr, rr, v);
    }
    if (rr > mid) {
        modify(rs, mid + 1, r, lr, rr, v);
    }
    pushup(u);
}

void Modify(int x, int y, ll z) {
    while (tp[x] != tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) {
            swap(x, y);
        }
        modify(1, 1, n, dfn[tp[x]], dfn[x], z);
        x = fa[tp[x]];
    }
    if (dep[x] > dep[y]) {
        swap(x, y);
    }
    modify(1, 1, n, dfn[x], dfn[y], z);
    return ;
}

ll query(int u, int l, int r, int lr, int rr) {
    if (lr <= l && r <= rr) {
        return t[u].val;
    }
    pushdown(u, l, r);
    ll ans = 0;
    if (lr <= mid) {
        ans = (ans + query(ls, l, mid, lr, rr)) % mod;
    }
    if (rr > mid) {
        ans = (ans + query(rs, mid + 1, r, lr, rr)) % mod;
    }
    return ans;
}

ll Query(int x, int y) {
    ll ans = 0;
    while (tp[x] != tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) {
            swap(x, y);
        }
        ans = (ans + query(1, 1, n, dfn[tp[x]], dfn[x])) % mod;
        x = fa[tp[x]];
    }
    if (dep[x] > dep[y]) {
        swap(x, y);
    }
    ans = (ans + query(1, 1, n, dfn[x], dfn[y])) % mod;
    return ans;
}

#undef ls
#undef rs
#undef mid

int main() {
    n = read<int>(), m = read<int>();
    rt = read<int>(), mod = read<int>();
    rep (i, 1, n, 1) {
        val[i] = read<int>();
    }
    int x, y;
    rep (i, 1, n - 1, 1) {
        x = read<int>(), y = read<int>();
        e[x].emplace_back(y);
        e[y].emplace_back(x);
    }
    dfs(rt, 0);
    getpos(rt, rt);
    build(1, 1, n);
    int op, z;
    rep (i, 1, m, 1) {
        op = read<int>(), x = read<int>();
        if (op == 1) {
            y = read<int>(), z = read<ll>();
            Modify(x, y, z);
        }
        if (op == 2) {
            y = read<int>();
            cout << Query(x, y) % mod << '\n';
        }
        if (op == 3) {
            z = read<ll>();
            modify(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, z);
        }
        if (op == 4) {
            cout << query(1, 1, n, dfn[x], dfn[x] + siz[x] - 1) % mod << '\n';
        }
    }
    return 0;
}

P4211 [LNOI2014] LCA - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)

詢問 LCA 的深度,其實就是在樹上差分中,一個節點權值 \(+ 1\),另一個點求該節點到根節點的路徑和。在樹鏈剖分中,就是將根節點到一個節點這條鏈上所有的點 \(+ 1\),另一個節點求該節點到根節點的路徑的權值和。

可以經詢問進行離線處理,離線來完成這道題,具體看程式碼。

//The code was written by yifan, and yifan is neutral!!!

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define bug puts("NOIP rp ++!");
#define lowbit(x) (x & (-x))
#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)

template<typename T>
inline T read() {
    T x = 0;
    bool fg = 0;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        fg |= (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    return fg ? ~x + 1 : x;
}

const int N = 5e4 + 5;
const int mod = 201314;

int n, m;
int fa[N], dep[N], siz[N], son[N], tp[N], pos[N];
int dfn[N];
vector<int> e[N];

struct ask {
    int be, R, z;
    bool fg;

    int operator < (const ask& b) const {
        return R < b.R;
    }
} xunwen[N << 1];

struct seg {
    int val, tag;

    seg operator + (const seg& b) {
        seg& a = *this, res;
        res.val = (a.val + b.val) % mod;
        return res;
    }
} t[N << 2];

struct ANS {
    ll ans1, ans2;
} ans[N << 1];

void dfs(int u) {
    siz[u] = 1;
    dep[u] = dep[fa[u]] + 1;
    for (int v : e[u]) {
        if (v == fa[u]) continue ;
        dfs(v);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) {
            son[u] = v;
        }
    }
}

void gettop(int u, int Top) {
    static int t = 0;
    tp[u] = Top;
    dfn[u] = ++ t;
    pos[t] = u;
    if (!son[u])    return ;
    gettop(son[u], Top);
    for (int v : e[u]) {
        if (v == fa[u] || v == son[u]) continue ;
        gettop(v, v);
    }
}

void color(int u, int l, int r, int co) {
    t[u].val = (t[u].val + (r - l + 1) * co) % mod;
    if (l < r) {
        t[u].tag = (t[u].tag + co) % mod;
    }
}

void pushdown(int u, int l, int r) {
    if (t[u].tag && l < r) {
        color(ls, l, mid, t[u].tag);
        color(rs, mid + 1, r, t[u].tag);
    }
    t[u].tag = 0;
}

void modify(int u, int l, int r, int lr, int rr) {
    if (lr <= l && r <= rr) {
        color(u, l, r, 1);
        return ;
    }
    pushdown(u, l, r);
    if (lr <= mid) {
        modify(ls, l, mid, lr, rr);
    }
    if (rr > mid) {
        modify(rs, mid + 1, r, lr, rr);
    }
    t[u] = t[ls] + t[rs];
}

void Modify(int x, int y) {
    while (tp[x] != tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) {
            swap(x, y);
        }
        modify(1, 1, n, dfn[tp[x]], dfn[x]);
        x = fa[tp[x]];
    }
    if (dep[x] > dep[y]) {
        swap(x, y);
    }
    modify(1, 1, n, dfn[x], dfn[y]);
}

ll query(int u, int l, int r, int lr, int rr) {
    if (lr <= l && r <= rr) {
        return t[u].val;
    }
    pushdown(u, l, r);
    ll ans = 0;
    if (lr <= mid) {
        ans += query(ls, l, mid, lr, rr);
    }
    if (rr > mid) {
        ans += query(rs, mid + 1, r, lr, rr);
    }
    return ans % mod;
}

ll Query(int x, int y) {
    ll ans = 0;
    while (tp[x] != tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) {
            swap(x, y);
        }
        ans += query(1, 1, n, dfn[tp[x]], dfn[x]);
        x = fa[tp[x]];
    }
    if (dep[x] > dep[y]) {
        swap(x, y);
    }
    ans += query(1, 1, n, dfn[x], dfn[y]);
    return ans % mod;
}

int main() {
    n = read<int>(), m = read<int>();
    for (int i = 2; i <= n; ++ i) {
        fa[i] = read<int>() + 1;
        e[fa[i]].emplace_back(i);
        e[i].emplace_back(fa[i]);
    }
    for (int i = 1; i <= m; ++ i) {
        int l = read<int>(), r = read<int>() + 1, z = read<int>() + 1;
        xunwen[i] = ask{i, l, z, 0};
        xunwen[i + m] = ask{i, r, z, 1};
    }
    dep[1] = 1;
    dfs(1);
    gettop(1, 1);
    m <<= 1;
    sort(xunwen + 1, xunwen + m + 1, [](ask& a, ask& b) {
        return a.R < b.R;
    });
    int now = 0;
    for (int i = 1; i <= m; ++ i) {
        while (now < xunwen[i].R) {
            Modify(1, ++ now);
        }
        int j = xunwen[i].be;
        if (xunwen[i].fg) {
            ans[j].ans1 = Query(1, xunwen[i].z);
        } else {
            ans[j].ans2 = Query(1, xunwen[i].z);
        }
    }
    m >>= 1;
    for (int i = 1; i <= m; ++ i) {
        printf("%lld\n", (ans[i].ans1 - ans[i].ans2 + mod) % mod);
    }
    return 0;
}

P4216 [SCOI2015] 情報傳遞 - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)

這個題就是樹鏈剖分與樹狀陣列的搭配,由於風險值會隨著時間變化,風險值有一個限度,我們可以利用當前時間減去風險值來得到一個時間節點,在這個時間節點之前開始蒐集情報的人就會產生威脅。

由於只是對一條鏈產生威脅,所以可以使用差分,用樹狀陣列來維護。

//The code was written by yifan, and yifan is neutral!!!

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define bug puts("NOIP rp ++!");
#define lowbit(x) (x & (-x))

template<typename T>
inline T read() {
    T x = 0;
    bool fg = 0;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        fg |= (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    return fg ? ~x + 1 : x;
}

const int N = 2e5 + 5;

int n, q, rt, tot;
int fa[N], dep[N], siz[N], son[N], tp[N];
int opt[N], X[N], Y[N], L[N], R[N];
ll t[N], ans[N];
vector<int> e[N], tim[N];

void dfs(int u) {
    L[u] = ++ tot;
    dep[u] = dep[fa[u]] + 1;
    siz[u] = 1;
    for (int &v : e[u]) {
        dfs(v);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) {
            son[u] = v;
        }
    }
    R[u] = tot;
}

void gettop(int u, int Top) {
    tp[u] = Top;
    if (!son[u]) return ;
    gettop(son[u], Top);
    for (int v : e[u]) {
        if (v == son[u])    continue ;
        gettop(v, v);
    }
}

int Lca(int x, int y) {
    while (tp[x] ^ tp[y]) {
        if (dep[tp[x]] < dep[tp[y]]) {
            swap(x, y);
        }
        x = fa[tp[x]];
    }
    if (dep[x] > dep[y]) {
        swap(x, y);
    }
    return x;
}

void modify(int x, int v) {
    while (x <= n) {
        t[x] += v;
        x += lowbit(x);
    }
}

ll query(int x) {
    ll ans = 0;
    while (x) {
        ans += t[x];
        x -= lowbit(x);
    }
    return ans;
}

ll Query(int x, int y) {
    int lca = Lca(x, y);
    return query(L[x]) + query(L[y]) - query(L[lca]) - query(L[fa[lca]]);
}

int dis(int x, int y) {
    int lca = Lca(x, y);
    return dep[x] + dep[y] - 2 * dep[lca] + 1;
}

int main() {
    n = read<int>();
    for (int i = 1; i <= n; ++ i) {
        fa[i] = read<int>();
        e[fa[i]].emplace_back(i);
    }
    for (rt = 1; fa[rt]; rt = fa[rt]);
    dfs(rt);
    gettop(rt, rt);
    q = read<int>();
    for (int i = 1; i <= q; ++ i) {
        opt[i] = read<int>();
        if (opt[i] == 1) {
            X[i] = read<int>(), Y[i] = read<int>();
            int c = read<int>();
            if (c < i) {
                tim[i - c - 1].emplace_back(i);
            }
        } else {
            X[i] = read<int>();
        }
    }
    for (int i = 1; i <= q; ++ i) {
        if (opt[i] == 2) {
            modify(L[X[i]], 1);
            modify(R[X[i]] + 1, -1);
        } 
        for (int &j : tim[i]) {
            ans[j] = Query(X[j], Y[j]);
        }
        if (opt[i] == 1) {
            printf("%d %lld\n", dis(X[i], Y[i]), ans[i]);
        }
    }
    return 0;
}

相關文章