題意簡述
給定一個長為 \(n = 2^k\) 的序列 \(\{a_0, \ldots, a_{n - 1}\}\),你需要使用資料結構維護它,支援 \(m\) 次以下操作:
- 單點加:\(a_x \gets a_x + y\);
- 區間查:\(\sum \limits _ {i = l} ^ r a_i\);
- 全域性下標與:\(a'_{i \operatorname{and} x} \gets a_{i}\),即把 \(a_i\) 累加到新的 \(a\) 的第 \(i \operatorname{and} x\) 位上;
- 全域性下標或:\(a'_{i \operatorname{or} x} \gets a_{i}\);
- 全域性下標異或:\(a'_{i \operatorname{xor} x} \gets a_{i}\);
\(k \leq 19\),\(m \leq 2^{19}\)。
題目分析
資料結構題,發現 \(n = 2^k\),於是想到線段樹,發現此時線段樹也是一棵對於下標的 Trie 樹,那麼我們只需要思考對於下標的位運算怎麼處理。
按位考慮,套路化地發現,\(\operatorname{xor}\) 的本質是將特定的某些位翻轉,即如果 \(x\) 的第 \(i\) 位上為 \(1\),那麼線段樹上,從葉子往根數第 \(i\) 層(葉子為第 \(-1\) 層)的左右兒子交換;\(\operatorname{and}\) 的本質是將某些特定的位設為 \(0\),即若 \(x\) 第 \(i\) 位為 \(0\),那麼線段樹上第 \(i\) 層的右兒子合併到左兒子上,合併為線段樹合併;類似 \(\operatorname{or}\) 的本質是將某些特定的位設為 \(1\),即若 \(x\) 第 \(i\) 位為 \(1\),那麼線段樹上第 \(i\) 層的左兒子合併到右兒子上,合併為線段樹合併。
我們先來考慮線段樹合併的正確性。我們知道,線段樹合併的時間複雜度約等於兩棵線段樹重合的結點數,約等於較小的線段樹結點數。由於合併後也會減小這麼多點,所以整個過程下來,勢能分析知,複雜度為整個過程線段樹結點數。除了初始 \(n\) 和節點外,還有操作 \(1\) 帶來的 \(\mathcal{O}(m \log n)\) 個結點,其他操作不會帶來新的節點,所以總的時間複雜度為 \(\mathcal{O}(n + m \log n)\),是正確的。
但是顯然不可以每次暴力對那麼多結點依次進行合併操作、交換操作,所以需要打懶惰標記。標記在層與層之間是獨立的。我們考慮某一層的操作時間軸,由三種操作構成。如果出現了一次合併操作,那麼之前的全部操作都會無效。所以,總是可以將時間軸簡化成等價的一次合併後跟著若干次交換子樹。這樣標記也就好打了。
時間複雜度:\(\mathcal{O}(n + m \log n)\)。
程式碼
#include <cstdio>
#include <iostream>
using namespace std;
const int MAX = 1 << 27;
char buf[MAX], *p = buf, obuf[MAX], *op = obuf;
#ifdef XuYueming
# define fread(_, __, ___, ____)
#else
# define getchar() *p++
#endif
#define putchar(x) *op++ = x
#define isdigit(x) ('0' <= x && x <= '9')
#define __yzh__(x) for (; x isdigit(ch); ch = getchar())
template <typename T>
inline void read(T &x) {
x = 0; char ch = getchar(); __yzh__(!);
__yzh__( ) x = (x << 3) + (x << 1) + (ch ^ 48);
}
template <typename T>
inline void write(T x) {
static short stack[20], top(0);
do stack[++top] = x % 10; while (x /= 10);
while (top) putchar(stack[top--] | 48);
putchar('\n');
}
const int K = 20, N = 1 << K | 10;
using lint = long long;
int n, m, k;
int tim[K], ctim[K]; // 總時間,最後一次左子樹合併到右子樹的時間
namespace Segment_Tree {
int root, tot;
struct node {
int ls, rs;
int t;
lint sum;
} tree[N * 2 + N * K];
#define Ls tree[idx].ls
#define Rs tree[idx].rs
int combine(int, int, int);
inline void pushdown(int idx, int dpt) {
if (tree[idx].t < ctim[dpt])
Rs = combine(Ls, Rs, dpt - 1), Ls = 0, tree[idx].t = ctim[dpt];
if ((tim[dpt] - tree[idx].t) & 1)
swap(Ls, Rs);
tree[idx].t = tim[dpt];
}
int combine(int u, int v, int dpt) {
if (!u || !v) return u | v;
tree[u].sum += tree[v].sum;
if (!~dpt) return u;
pushdown(u, dpt), pushdown(v, dpt);
tree[u].ls = combine(tree[u].ls, tree[v].ls, dpt - 1);
tree[u].rs = combine(tree[u].rs, tree[v].rs, dpt - 1);
return u;
}
void modify(int &idx, int trl, int trr, int dpt, int p, int val) {
if (p > trr || p < trl) return;
if (!idx) idx = ++tot, ~dpt && (tree[idx].t = tim[dpt]);
tree[idx].sum += val;
if (!~dpt) return;
int mid = (trl + trr) >> 1;
pushdown(idx, dpt);
modify(Ls, trl, mid, dpt - 1, p, val);
modify(Rs, mid + 1, trr, dpt - 1, p, val);
}
lint query(int idx, int trl, int trr, int dpt, int l, int r) {
if (trl > r || trr < l || !idx) return 0;
if (l <= trl && trr <= r) return tree[idx].sum;
pushdown(idx, dpt);
int mid = (trl + trr) >> 1;
return query(Ls, trl, mid, dpt - 1, l, r) + query(Rs, mid + 1, trr, dpt - 1, l, r);
}
}
using Segment_Tree::root;
using Segment_Tree::modify;
using Segment_Tree::query;
signed main() {
#ifndef XuYueming
freopen("jeden.in", "r", stdin);
freopen("jeden.out", "w", stdout);
#endif
fread(buf, 1, MAX, stdin);
read(n), read(m), k = __lg(n);
for (int i = 0, x; i < n; ++i) read(x), modify(root, 0, n - 1, k - 1, i, x);
for (int x, y; m--; ) {
char op; do op = getchar(); while (op < 'a' || op > 'z');
op = getchar();
if (op == 'd')
read(x), read(y), modify(root, 0, n - 1, k - 1, x, y);
else if (op == 'u')
read(x), read(y), write(query(root, 0, n - 1, k - 1, x, y));
else if (op == 'n') {
read(x);
for (int i = 0; i < k; ++i)
if (!(x & 1 << i))
ctim[i] = ++tim[i], ++tim[i];
} else if (op == 'r') {
read(x);
for (int i = 0; i < k; ++i)
if (x & 1 << i)
ctim[i] = ++tim[i];
} else {
read(x);
for (int i = 0; i < k; ++i)
if (x & 1 << i)
++tim[i];
}
}
fwrite(obuf, 1, op - obuf, stdout);
return 0;
}