lgP1253 扶蘇的問題

chenfy27發表於2024-07-14

給定長度為 n 的序列 A（下標從 1 開始），有如下 3 種操作：

  • 1 l r x,將區間[l,r]中的每個數都修改為x。
  • 2 l r x,將區間[l,r]中的每個數都加上x。
  • 3 l r,查詢區間[l,r]內的最大值。

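分析：設定 2 個懶標記——賦值標記 set 與增加標記 add，一個標記代表「先把區間全部賦值為 set，再全部加上 add」。合併（下推）標記時先處理賦值標記，再處理增加標記：新的賦值會覆蓋原有的賦值，並清除尚未下推的增加量；新的增加量則累加到原有的增加量上。區間最大值在賦值後變為 set，加上 add 後整體平移 add。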

#include <bits/stdc++.h>
// 64-bit integer type used for element values, tag components and answers.
using llong = long long;

// Sentinel "infinity" (10^18). Tag treats it as "no pending assign/add",
// and Info starts from -inf, the neutral value for max.
constexpr llong inf = 1000000000000000000LL;

/**
 * Generic lazy-propagation segment tree over a 0-indexed sequence.
 *
 * All ranges are half-open [L, R). Nodes live in flat arrays: the root is
 * node 0, and node x covering [l, r) has children 2x+1 covering [l, m) and
 * 2x+2 covering [m, r), where m = (l + r) / 2.
 *
 * Requirements on Info:
 *   - Info() must be the identity of operator+ (it is returned for nodes
 *     that lie completely outside a query range);
 *   - Info + Info merges a left segment with the adjacent right segment;
 *   - info.apply(t) must update an aggregate as if t were applied to every
 *     element of its segment;
 *   - Info must be assignable from the element type T passed to init().
 * Requirements on Tag:
 *   - Tag() must be the no-op update;
 *   - tag.apply(t) must compose t after the update already held in tag.
 * add() additionally needs Info::operator+=, and print() needs operator<<
 * for Info and Tag; these members are only instantiated when used.
 */
template<class Info, class Tag>
struct LazySegmentTree {
    int n;                   // number of elements
    std::vector<Info> info;  // aggregate of each node's segment
    std::vector<Tag> tag;    // update already applied to info[x] but not yet pushed to x's children

    LazySegmentTree() : n(0) {}
    LazySegmentTree(int _n, Info v = Info()) {
        init(_n, v);
    }
    template<class T>
    LazySegmentTree(const std::vector<T> &v) {
        init(v);
    }

    /// (Re)builds the tree with _n copies of v.
    void init(int _n, Info v = Info()) {
        init(std::vector<Info>(_n, v));
    }

    /// (Re)builds the tree from v; element i becomes leaf i.
    template<class T>
    void init(const std::vector<T> &v) {
        n = static_cast<int>(v.size());
        info.clear();
        tag.clear();
        if (n == 0) {
            // An empty tree has no nodes. Building it would recurse forever
            // on [0, 0) (l + 1 never equals r); without nodes, every range
            // operation below finds no overlap and returns immediately.
            return;
        }
        // Leaves sit at depth <= ceil(log2 n), so node indices stay below
        // 2^(ceil(log2 n) + 1) <= 4 * 2^floor(log2 n).
        int pow2 = 1;
        while (pow2 <= n / 2) {
            pow2 *= 2;
        }
        info.assign(4 * pow2, Info());
        tag.assign(4 * pow2, Tag());
        build(0, 0, n, v);
    }

    /// Fills node x, covering [l, r), from v[l..r). Requires l < r.
    template<class T>
    void build(int x, int l, int r, const std::vector<T> &v) {
        if (l + 1 == r) {
            info[x] = v[l];
            return;
        }
        int m = (l + r) / 2;
        build(2*x+1, l, m, v);
        build(2*x+2, m, r, v);
        pushup(x);
    }

    /// Recomputes node x's aggregate from its two children.
    void pushup(int x) {
        info[x] = info[2*x+1] + info[2*x+2];
    }

    /// Applies update t to node x's aggregate and records it as pending.
    void apply(int x, const Tag &t) {
        info[x].apply(t);
        tag[x].apply(t);
    }

    /// Pushes node x's pending update down to both children, then clears it.
    void pushdown(int x) {
        apply(2*x+1, tag[x]);
        apply(2*x+2, tag[x]);
        tag[x] = Tag();
    }

    /// Replaces element p with v. Requires 0 <= p < n.
    void assign(int x, int l, int r, int p, const Info &v) {
        if (l + 1 == r) {
            info[x] = v;
            return;
        }
        int m = (l + r) / 2;
        pushdown(x);
        if (p < m) {
            assign(2*x+1, l, m, p, v);
        } else {
            assign(2*x+2, m, r, p, v);
        }
        pushup(x);
    }
    void assign(int p, const Info &v) {
        assign(0, 0, n, p, v);
    }

    /// Performs info[p] += v on element p. Requires 0 <= p < n and Info::operator+=.
    void add(int x, int l, int r, int p, const Info &v) {
        if (l + 1 == r) {
            info[x] += v;
            return;
        }
        int m = (l + r) / 2;
        pushdown(x);
        if (p < m) {
            add(2*x+1, l, m, p, v);
        } else {
            add(2*x+2, m, r, p, v);
        }
        pushup(x);
    }
    void add(int p, const Info &v) {
        add(0, 0, n, p, v);
    }

    /// Returns the combined aggregate of [L, R), or Info() if the range is empty.
    Info rangeQuery(int x, int l, int r, int L, int R) {
        if (R <= l || r <= L) {
            return Info();
        }
        if (L <= l && r <= R) {
            return info[x];
        }
        int m = (l + r) / 2;
        pushdown(x);
        return rangeQuery(2*x+1, l, m, L, R) + rangeQuery(2*x+2, m, r, L, R);
    }
    Info rangeQuery(int L, int R) {
        return rangeQuery(0, 0, n, L, R);
    }

    /// Applies update t to every element of [L, R).
    void rangeApply(int x, int l, int r, int L, int R, const Tag &t) {
        if (R <= l || r <= L) {
            return;
        }
        if (L <= l && r <= R) {
            apply(x, t);
            return;
        }
        int m = (l + r) / 2;
        pushdown(x);
        rangeApply(2*x+1, l, m, L, R, t);
        rangeApply(2*x+2, m, r, L, R, t);
        pushup(x);
    }
    void rangeApply(int L, int R, const Tag &t) {
        rangeApply(0, 0, n, L, R, t);
    }

    /**
     * Returns the smallest index i in [L, R) whose leaf satisfies pred, or -1.
     * Whole segments for which pred is false are skipped, so pred must be
     * true on a segment whenever it holds for some element in it
     * (e.g. `max >= k`).
     */
    template<class F>
    int findFirst(int x, int l, int r, int L, int R, F &&pred) {
        if (R <= l || r <= L) {
            return -1;
        }
        if (L <= l && r <= R && !pred(info[x])) {
            return -1;
        }
        if (l + 1 == r) {
            return l;
        }
        int m = (l + r) / 2;
        pushdown(x);
        int res = findFirst(2*x+1, l, m, L, R, pred);
        if (res == -1) {
            res = findFirst(2*x+2, m, r, L, R, pred);
        }
        return res;
    }
    template<class F>
    int findFirst(int L, int R, F &&pred) {
        return findFirst(0, 0, n, L, R, pred);
    }

    /// Mirror of findFirst: the largest matching index in [L, R), or -1.
    template<class F>
    int findLast(int x, int l, int r, int L, int R, F &&pred) {
        if (R <= l || r <= L) {
            return -1;
        }
        if (L <= l && r <= R && !pred(info[x])) {
            return -1;
        }
        if (l + 1 == r) {
            return l;
        }
        int m = (l + r) / 2;
        pushdown(x);
        int res = findLast(2*x+2, m, r, L, R, pred);
        if (res == -1) {
            res = findLast(2*x+1, l, m, L, R, pred);
        }
        return res;
    }
    template<class F>
    int findLast(int L, int R, F &&pred) {
        return findLast(0, 0, n, L, R, pred);
    }

    /// Debug dump of every node (pre-order) to std::cerr.
    void print(int x, int l, int r) {
        std::cerr << x << "(" << l << "," << r << ") " << info[x] << " " << tag[x] << "\n";
        if (l + 1 < r) {
            int m = (l + r) / 2;
            print(2*x+1, l, m);
            print(2*x+2, m, r);
        }
    }
    void print() {
        std::cerr << "-----------------------------\n";
        if (n > 0) {  // an empty tree has no node 0 to print
            print(0, 0, n);
        }
    }
};

struct Tag {
    llong set;
    llong add;
    Tag(llong s=inf, llong a=inf):set(s),add(a) {}
    void apply(Tag t) {
        if (t.set != inf) {
            set = t.set;
            add = inf;
        }
        if (t.add != inf) {
            if (add == inf) {
                add = t.add;
            } else {
                add += t.add;
            }
        }
    }
    friend std::ostream& operator<<(std::ostream &out, Tag &tag) {
        out << "tag:(" << tag.set << "," << tag.add << ")";
        return out;
    }
};

/**
 * Segment aggregate: the maximum value in a segment.
 *
 * Info() holds -inf, which is the identity for operator+ (max), so it is a
 * safe result for empty query ranges. The constructor is intentionally
 * implicit so raw llong values convert to Info (e.g. when the tree is built
 * from a std::vector<llong>).
 */
struct Info {
    llong max;
    Info(llong a=-inf):max(a) {}
    /// Applies a pending update to the whole segment: assign first, then add.
    void apply(Tag t) {
        if (t.set != inf) {
            max = t.set;   // every element becomes t.set, so does the max
        }
        if (t.add != inf) {
            max += t.add;  // adding a constant shifts the max by that constant
        }
    }
    /// Merges two adjacent segments.
    friend Info operator+(const Info &a, const Info &b) {
        return Info(std::max(a.max, b.max));
    }
    friend std::ostream& operator<<(std::ostream &out, const Info &info) {
        out << "info:(" << info.max << ")";
        return out;
    }
};

void solve() {
    int n, m;
    std::cin >> n >> m;
    std::vector<llong> A(n);
    for (int i = 0; i < n; i++) {
        std::cin >> A[i];
    }
    LazySegmentTree<Info,Tag> tr(A);
    for (int i = 0; i < m; i++) {
        int op, l, r, x;
        std::cin >> op >> l >> r;
        if (op == 1) {
            std::cin >> x;
            tr.rangeApply(l-1, r, Tag(x, inf));
        } else if (op == 2) {
            std::cin >> x;
            tr.rangeApply(l-1, r, Tag(inf, x));
        } else if (op == 3) {
            std::cout << tr.rangeQuery(l-1, r).max << "\n";
        }
    }
}

int main() {
    // Untie cin from cout and drop C-stdio synchronisation for fast bulk I/O.
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);
    // The input holds exactly one test case.
    const int testCount = 1;
    for (int tc = 0; tc < testCount; ++tc) {
        solve();
    }
    return 0;
}