線段樹也能是 Trie 樹 題解

XuYueming發表於2024-11-02

題意簡述

給定一個長為 \(n = 2^k\) 的序列 \(\{a_0, \ldots, a_{n - 1}\}\),你需要使用資料結構維護它,支援 \(m\) 次以下操作:

  1. 單點加:\(a_x \gets a_x + y\)
  2. 區間查:\(\sum \limits _ {i = l} ^ r a_i\)
  3. 全域性下標與:\(a'_{i \operatorname{and} x} \gets a_{i}\),即把 \(a_i\) 累加到新的 \(a\) 的第 \(i \operatorname{and} x\) 位上;
  4. 全域性下標或:\(a'_{i \operatorname{or} x} \gets a_{i}\)
  5. 全域性下標異或:\(a'_{i \operatorname{xor} x} \gets a_{i}\)

\(k \leq 19\)\(m \leq 2^{19}\)

題目分析

資料結構題,發現 \(n = 2^k\),於是想到線段樹,發現此時線段樹也是一棵對於下標的 Trie 樹,那麼我們只需要思考對於下標的位運算怎麼處理。

按位考慮,套路化地發現,\(\operatorname{xor}\) 的本質是將特定的某些位翻轉,即如果 \(x\) 的第 \(i\) 位上為 \(1\),那麼線段樹上,從葉子往根數第 \(i\) 層(葉子為第 \(-1\) 層)的左右兒子交換;\(\operatorname{and}\) 的本質是將某些特定的位設為 \(0\),即若 \(x\)\(i\) 位為 \(0\),那麼線段樹上第 \(i\) 層的右兒子合併到左兒子上,合併為線段樹合併;類似 \(\operatorname{or}\) 的本質是將某些特定的位設為 \(1\),即若 \(x\)\(i\) 位為 \(1\),那麼線段樹上第 \(i\) 層的左兒子合併到右兒子上,合併為線段樹合併。

我們先來考慮線段樹合併的正確性。我們知道,線段樹合併的時間複雜度約等於兩棵線段樹重合的結點數,約等於較小的線段樹結點數。由於合併後也會減小這麼多點,所以整個過程下來,勢能分析知,複雜度為整個過程線段樹結點數。除了初始 \(n\) 和節點外,還有操作 \(1\) 帶來的 \(\mathcal{O}(m \log n)\) 個結點,其他操作不會帶來新的節點,所以總的時間複雜度為 \(\mathcal{O}(n + m \log n)\),是正確的。

但是顯然不可以每次暴力對那麼多結點依次進行合併操作、交換操作,所以需要打懶惰標記。標記在層與層之間是獨立的。我們考慮某一層的操作時間軸,由三種操作構成。如果出現了一次合併操作,那麼之前的全部操作都會無效。所以,總是可以將時間軸簡化成等價的一次合併後跟著若干次交換子樹。這樣標記也就好打了。

時間複雜度:\(\mathcal{O}(n + m \log n)\)

程式碼

#include <cstdio>
#include <iostream>
using namespace std;

const int MAX = 1 << 27;
char buf[MAX], *p = buf, obuf[MAX], *op = obuf;
#ifdef XuYueming
# define fread(_, __, ___, ____)
#else
# define getchar() *p++
#endif
#define putchar(x) *op++ = x
#define isdigit(x) ('0' <= x && x <= '9')
#define __yzh__(x) for (; x isdigit(ch); ch = getchar())
template <typename T>
inline void read(T &x) {
    x = 0; char ch = getchar(); __yzh__(!);
    __yzh__( ) x = (x << 3) + (x << 1) + (ch ^ 48);
}
template <typename T>
inline void write(T x) {
    static short stack[20], top(0);
    do stack[++top] = x % 10; while (x /= 10);
    while (top) putchar(stack[top--] | 48);
    putchar('\n');
}

const int K = 20, N = 1 << K | 10;

using lint = long long;

int n, m, k;
int tim[K], ctim[K];  // 總時間,最後一次左子樹合併到右子樹的時間

namespace Segment_Tree {
    int root, tot;
    
    struct node {
        int ls, rs;
        int t;
        lint sum;
    } tree[N * 2 + N * K];
    
    #define Ls tree[idx].ls
    #define Rs tree[idx].rs
    
    int combine(int, int, int);
    
    inline void pushdown(int idx, int dpt) {
        if (tree[idx].t < ctim[dpt])
            Rs = combine(Ls, Rs, dpt - 1), Ls = 0, tree[idx].t = ctim[dpt];
        if ((tim[dpt] - tree[idx].t) & 1)
            swap(Ls, Rs);
        tree[idx].t = tim[dpt];
    }
    
    int combine(int u, int v, int dpt) {
        if (!u || !v) return u | v;
        tree[u].sum += tree[v].sum;
        if (!~dpt) return u;
        pushdown(u, dpt), pushdown(v, dpt);
        tree[u].ls = combine(tree[u].ls, tree[v].ls, dpt - 1);
        tree[u].rs = combine(tree[u].rs, tree[v].rs, dpt - 1);
        return u;
    }
    
    void modify(int &idx, int trl, int trr, int dpt, int p, int val) {
        if (p > trr || p < trl) return;
        if (!idx) idx = ++tot, ~dpt && (tree[idx].t = tim[dpt]);
        tree[idx].sum += val;
        if (!~dpt) return;
        int mid = (trl + trr) >> 1;
        pushdown(idx, dpt);
        modify(Ls, trl, mid, dpt - 1, p, val);
        modify(Rs, mid + 1, trr, dpt - 1, p, val);
    }
    
    lint query(int idx, int trl, int trr, int dpt, int l, int r) {
        if (trl > r || trr < l || !idx) return 0;
        if (l <= trl && trr <= r) return tree[idx].sum;
        pushdown(idx, dpt);
        int mid = (trl + trr) >> 1;
        return query(Ls, trl, mid, dpt - 1, l, r) + query(Rs, mid + 1, trr, dpt - 1, l, r);
    }
}

using Segment_Tree::root;
using Segment_Tree::modify;
using Segment_Tree::query;

signed main() {
    #ifndef XuYueming
    freopen("jeden.in", "r", stdin);
    freopen("jeden.out", "w", stdout);
    #endif
    fread(buf, 1, MAX, stdin);
    read(n), read(m), k = __lg(n);
    for (int i = 0, x; i < n; ++i) read(x), modify(root, 0, n - 1, k - 1, i, x);
    for (int x, y; m--; ) {
        char op; do op = getchar(); while (op < 'a' || op > 'z');
        op = getchar();
        if (op == 'd')
            read(x), read(y), modify(root, 0, n - 1, k - 1, x, y);
        else if (op == 'u')
            read(x), read(y), write(query(root, 0, n - 1, k - 1, x, y));
        else if (op == 'n') {
            read(x);
            for (int i = 0; i < k; ++i)
                if (!(x & 1 << i))
                    ctim[i] = ++tim[i], ++tim[i];
        } else if (op == 'r') {
            read(x);
            for (int i = 0; i < k; ++i)
                if (x & 1 << i)
                    ctim[i] = ++tim[i];
        } else {
            read(x);
            for (int i = 0; i < k; ++i)
                if (x & 1 << i)
                    ++tim[i];
        }
    }
    fwrite(obuf, 1, op - obuf, stdout);
    return 0;
}