線段樹維護區間等差數列

zhujio發表於2024-03-10

線段樹維護區間等差數列

我們採用兩個懶標記，分別維護等差數列的首項 k 和公差 d。

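維護時有個細節：當同一個等差數列標記要分給左右兩個子區間時（下傳懶標記，或修改區間跨過 mid 而被拆開時），

左子區間的首項仍是 k，公差仍是 d；

右子區間的首項則應變成 $k + len \cdot d$，公差仍是 d。

這裡的 len 是左邊那一段（數列在左子區間內的部分）的長度，而不是右子區間的長度：右子區間的第一個位置前面恰好有 len 項。對應到程式，下傳時是 `mid - l + 1`，拆分修改時是 `mid - x + 1`。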

牛牛的等差數列

線段樹維護區間等差數列
#include <bits/stdc++.h>
using namespace std;
#define endl "\n"
#define int long long
typedef long long ll;

// Array capacity: maximum n plus some slack.
constexpr int N = 2e5 + 100;

// All sums are kept modulo `mod` = 111546435 = 3*5*7*11*13*17*19*23.
// `mod` is odd, so 2 is invertible and inv2 = (mod + 1) / 2 is its inverse.
constexpr int mod = 111546435, inv2 = 55773218;

// Initial values, 1-indexed (a[1..n]).
int a[N];

// Segment-tree node.
//   sum : sum of the node's range, congruent modulo `mod`
//   k, d: pending additive tag for the children -- first term and common
//         difference of a progression already counted in `sum` but not yet
//         pushed down.
struct node {
    int sum;
    int k;
    int d;
};
node Seg[N * 4];

// Add the arithmetic progression k, k+d, ..., k+(len-1)*d to the `len`
// elements covered by node `id`.
// The tag (k, d) is accumulated so it can later be pushed to the children.
// The node's sum grows by the progression total (first + last) * len / 2,
// evaluated modulo `mod` via the inverse of 2.
// Callers pass k and d already reduced modulo `mod`.
void settag(int id, int k, int d, int len) {
    Seg[id].k = (Seg[id].k + k) % mod;
    Seg[id].d = (Seg[id].d + d) % mod;
    // first + last = 2k + (len-1)d; every product is reduced right away so
    // intermediates stay far below 2^63.
    int first_plus_last = (k + k + d * (len - 1) % mod) % mod;
    int added = first_plus_last * len % mod * inv2 % mod;
    // Keep the stored sum reduced as well, consistent with k and d.
    // Previously it was only reduced on pushup, so leaves and fully-covered
    // nodes (e.g. the root) accumulated unreduced values that query() could
    // return directly.
    Seg[id].sum = (Seg[id].sum + added) % mod;
}

// Push node `id`'s pending progression tag (node covers [l, r]) to its two
// children and clear it.
// The left child [l, mid] starts the progression at k.
// The right child [mid+1, r] starts left_len terms later, at k + left_len*d.
void pushdown(int id, int l, int r) {
    const int k = Seg[id].k;
    const int d = Seg[id].d;
    if (k == 0 && d == 0) return;  // tags are additive, so (0, 0) changes nothing
    const int mid = (l + r) / 2;
    const int left_len = mid - l + 1;
    settag(id * 2, k, d, left_len);
    settag(id * 2 + 1, (k + left_len * d) % mod, d, r - mid);
    Seg[id].k = 0;
    Seg[id].d = 0;
}

// Recompute node `id`'s sum from its two children, modulo `mod`.
void pushup(int id) {
    const int left = id << 1;
    const int right = left | 1;
    Seg[id].sum = (Seg[left].sum + Seg[right].sum) % mod;
}

void build(int id, int l, int r) {
    Seg[id] = {0, 0, 0};
    if (l == r) {
        Seg[id].sum = a[l] % mod;
        return;
    }
    int mid = l + r >> 1;
    build(id * 2, l, mid); build(id * 2 + 1, mid + 1, r);
    pushup(id);
}    

// Add the progression k, k+d, k+2d, ... to positions x..y (position x gets k).
// Node `id` covers [l, r].
void modify(int id, int l, int r, int x, int y, int k, int d) {
    if (x <= l && r <= y) {
        // Node fully covered: tag it and stop descending.
        settag(id, k, d, r - l + 1);
        return;
    }
    pushdown(id, l, r);
    const int mid = (l + r) / 2;
    const int lc = id * 2;
    const int rc = id * 2 + 1;
    if (x > mid) {
        modify(rc, mid + 1, r, x, y, k, d);
    } else if (y <= mid) {
        modify(lc, l, mid, x, y, k, d);
    } else {
        // [x, y] straddles mid: the right part begins `skipped` terms into
        // the progression.
        modify(lc, l, mid, x, mid, k, d);
        const int skipped = mid - x + 1;
        modify(rc, mid + 1, r, mid + 1, y, (k + skipped * d) % mod, d);
    }
    pushup(id);
}

// Sum of positions [x, y] inside node `id`'s range [l, r].
// A fully covered node returns its stored sum as is; partial results are
// combined modulo `mod`.
int query(int id, int l, int r, int x, int y) {
    if (x <= l && y >= r) return Seg[id].sum;
    pushdown(id, l, r);
    const int mid = (l + r) / 2;
    int total = 0;
    if (x <= mid) total = (total + query(id * 2, l, mid, x, y)) % mod;
    if (y > mid) total = (total + query(id * 2 + 1, mid + 1, r, x, y)) % mod;
    return total;
}

// Read the array, then process q operations:
//   1 l r k d : add k, k+d, ..., k+(r-l)*d to a[l..r]
//   2 l r m   : print (a[l] + ... + a[r]) mod m
// Sums are maintained modulo 111546435. Reducing that result mod m is only
// correct when m divides 111546435; presumably the problem guarantees it
// (verify against the statement).
void solve() {
    // C++ `%` keeps the sign of the dividend, so negative inputs would
    // produce negative tags/sums and negative printed answers.
    // Map every input value into [0, mod) first.
    auto norm = [](int value) { return (value % mod + mod) % mod; };

    int n;
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        a[i] = norm(a[i]);
    }
    build(1, 1, n);

    int q;
    cin >> q;
    while (q--) {
        int op;
        cin >> op;
        if (op == 1) {
            int l, r, k, d;
            cin >> l >> r >> k >> d;
            modify(1, 1, n, l, r, norm(k), norm(d));
        } else {
            int l, r, m;
            cin >> l >> r >> m;
            cout << query(1, 1, n, l, r) % m << endl;
        }
    }
}

signed main() {
    // Untie and desynchronize the iostreams for fast bulk I/O; this program
    // reads and writes only through cin/cout.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    solve();
    return 0;
}
View Code

Space Harbour
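這題和上題差不多，不過這裡的標記是「覆蓋」（settag 直接賦值）而不是「累加」。一開始線段樹一直越界，後來在線段樹各個函式中特判了下標邊界（`id > N * 3`）才通過。

越界的真正原因是修改時會出現空區間：例如新港口 x 緊鄰左邊的港口 l 時，要修改的區間 [l+1, x-1] 是空的。modify 對 x > y 的區間永遠無法命中 `x == l && y == r`，會一路遞迴到葉子以下，節點編號不斷翻倍而越界。在 modify 開頭加上 `if (x > y) return;`（或呼叫前先判斷區間非空）才是根本的解法。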

線段樹維護區間等差數列
#include <bits/stdc++.h>
using namespace std;
#define endl "\n"
#define int long long
typedef long long ll;

// Array capacity for positions 1..n.
constexpr int N = 1e6 + 100;

int a[N];      // a[1..m]: harbour positions, in input order
int v[N];      // v[p]: value of the harbour at p; build treats v[p] == 0 as "no harbour"
int n, m, q;   // number of positions, initial harbours, queries
set<int> s;    // positions of all current harbours

// Segment-tree node.
//   sum : total cost over the node's range
//   k, d: pending *assignment* tag -- the range is overwritten with the
//         progression k, k+d, ...; pushdown treats k == 0 as "no tag".
struct node {
    int sum;
    int k;
    int d;
};
node Seg[N * 4];

// True iff x is a valid position, i.e. lies in [1, n].
bool check(int x) {
    return !(x < 1 || x > n);
}

// Recompute node `id`'s sum from its children.
// Nodes whose right child index exceeds 3N are left untouched.
void pushup(int id) {
    const int right = id * 2 + 1;
    if (right <= N * 3) {
        Seg[id].sum = Seg[right - 1].sum + Seg[right].sum;
    }
}

void settag(int id, int k, int d, int len) {
    if (id > N * 3) return;
    Seg[id].k = k; Seg[id].d = d;
    Seg[id].sum = (k + k + d * (len - 1)) * len / 2;
}

// Propagate node `id`'s pending assignment tag (node covers [l, r]) to its
// children, then clear it.
// The right child's progression starts left_len terms after the left one's.
// NOTE(review): k == 0 doubles as "no pending tag", so an assignment whose
// first term is 0 is never pushed down. In this program range assignments
// start at v[l] * (gap length), which is > 0 assuming harbour values are
// positive, and (0, 0) is only assigned to single positions. The sentinel
// appears safe here, but a separate flag would be more robust.
void pushdown(int id, int l, int r) {
    if (id > N * 3) return;
    const int k = Seg[id].k;
    if (!k) return;
    const int d = Seg[id].d;
    const int mid = (l + r) >> 1;
    const int left_len = mid - l + 1;
    settag(id * 2, k, d, left_len);
    settag(id * 2 + 1, k + left_len * d, d, r - mid);
    Seg[id].k = 0;
    Seg[id].d = 0;
}

// Build the subtree rooted at `id` over [l, r].
// A leaf holds the cost at its position:
//   - 0 for a harbour (v[l] != 0);
//   - otherwise (next harbour position - l) * (value of the previous harbour);
//   - 0 if either neighbouring harbour is missing.
void build(int id, int l, int r) {
    if (id > N * 3) return;
    if (l == r) {
        Seg[id].sum = 0;
        if (!v[l]) {
            auto it = s.lower_bound(l);
            // Validate the iterators before dereferencing them. The old
            // check(*it) / check(*prev(it)) dereferenced s.end() or stepped
            // before s.begin() first, which is undefined behaviour.
            bool has_both = it != s.end() && it != s.begin()
                            && check(*it) && check(*prev(it));
            if (has_both) Seg[id].sum = (*it - l) * v[*prev(it)];
        }
        return;
    }
    int mid = (l + r) >> 1;
    build(id * 2, l, mid);
    build(id * 2 + 1, mid + 1, r);
    pushup(id);
}

// Assign the progression k, k+d, k+2d, ... to positions x..y (x gets k).
// Node `id` covers [l, r].
// Callers must keep l <= x and y <= r; an empty range (x > y) is a no-op.
void modify(int id, int l, int r, int x, int y, int k, int d) {
    // solve() issues empty ranges, e.g. [l+1, x-1] when a new harbour x is
    // directly after the previous harbour l. An empty range never satisfies
    // x == l && y == r, so without this guard the recursion descended past
    // the leaves with id doubling each level and indexed Seg out of bounds.
    // With l <= x <= y <= r the recursion always stops at or above a leaf,
    // so no id bound check is needed here.
    if (x > y) return;
    if (x == l && y == r) {
        settag(id, k, d, r - l + 1);
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(id, l, r);
    if (x > mid) modify(id * 2 + 1, mid + 1, r, x, y, k, d);
    else if (y <= mid) modify(id * 2, l, mid, x, y, k, d);
    else {
        modify(id * 2, l, mid, x, mid, k, d);
        // The right half starts (mid - x + 1) terms into the progression.
        modify(id * 2 + 1, mid + 1, r, mid + 1, y, k + (mid - x + 1) * d, d);
    }
    pushup(id);
}

// Total cost over positions [x, y] within node `id`'s range [l, r].
int query(int id, int l, int r, int x, int y) {
    if (id > N * 3) return 0;
    if (l >= x && r <= y) return Seg[id].sum;
    pushdown(id, l, r);
    const int mid = (l + r) >> 1;
    const int from_left = (x <= mid) ? query(id * 2, l, mid, x, y) : 0;
    const int from_right = (y > mid) ? query(id * 2 + 1, mid + 1, r, x, y) : 0;
    return from_left + from_right;
}

// Read n, m, q, the m harbour positions and their values, then build the
// tree and answer q queries:
//   1 x y : build a harbour of value y at position x
//   2 x y : print the total cost over positions [x, y]
// Presumably x is not already a harbour in type-1 queries (the update below
// assumes it); verify against the problem statement.
void solve() {
    cin >> n >> m >> q;
    for (int i = 1; i <= m; i++) cin >> a[i], s.insert(a[i]);
    for (int i = 1; i <= m; i++) cin >> v[a[i]];

    build(1, 1, n);

    while (q--) {
        int op, x, y;
        cin >> op >> x >> y;
        if (op != 1) {
            cout << query(1, 1, n, x, y) << endl;
            continue;
        }
        v[x] = y;
        auto it = s.lower_bound(x);
        // Validate the iterators before dereferencing them. The old code read
        // *it and *prev(it) first and only then called check(), which is
        // undefined behaviour when x has no harbour on one side.
        const bool has_left = it != s.begin() && check(*prev(it));
        const bool has_right = it != s.end() && check(*it);
        // Positions strictly between the left harbour l and x now route to x:
        // cost at p becomes v[l] * (x - p). Skip the update when the gap is
        // empty (x == l + 1).
        if (has_left) {
            const int l = *prev(it);
            if (l + 1 <= x - 1) modify(1, 1, n, l + 1, x - 1, v[l] * (x - l - 1), -v[l]);
        }
        // x itself is now a harbour: its cost is 0.
        modify(1, 1, n, x, x, 0, 0);
        // Positions strictly between x and the right harbour r now use x's
        // value: cost at p becomes y * (r - p). Skip an empty gap (r == x + 1).
        if (has_right) {
            const int r = *it;
            if (x + 1 <= r - 1) modify(1, 1, n, x + 1, r - 1, y * (r - x - 1), -y);
        }
        s.insert(x);
    }
}

signed main() {
    // Untie and desynchronize the iostreams for fast bulk I/O; this program
    // reads and writes only through cin/cout.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    solve();
    return 0;
}
View Code

相關文章