acwing246 區間最大公約數

chenfy27發表於2024-06-18

給定長度為N的數列A,以及M條指令,每條指令可能是以下兩種之一:

  1. C l r d,表示把A[l],A[l+1],...A[r]都加上d。
  2. Q l r,表示查詢A[l],A[l+1],...A[r]的最大公約數。

對於每個詢問,輸出一個整數表示答案。

分析:利用差分陣列,將區間修改轉換成兩次單點修改。再用差分陣列構造出原陣列區間的最大公約數,例如gcd(a,b,c) = gcd(a,b-a,c-b) = gcd(a, gcd(b-a,c-b)),因此可以拆成兩部分計算:左邊的a是差分陣列的字首和(即A[l]),右邊是差分陣列在(l, r]上的區間gcd,二者都可以用線段樹來維護。

#include <bits/stdc++.h>
using i64 = long long;

/// Greatest common divisor of x and y, computed with the Euclidean algorithm.
///
/// The result is always non-negative, whatever the signs of the inputs
/// (the previous recursive version could return e.g. -2 for (4, -6), because
/// the sign of `%` follows the dividend). mygcd(0, 0) is 0 and
/// mygcd(x, 0) is |x|.
///
/// @param x first operand (may be negative or zero)
/// @param y second operand (may be negative or zero)
/// @return gcd(|x|, |y|)
/// @note |x| and |y| must be representable in i64 (i.e. neither input may be
///       the minimum i64 value); otherwise the negation below overflows.
i64 mygcd(i64 x, i64 y) {
    // Iterative form: no recursion depth to worry about.
    while (y != 0) {
        const i64 rem = x % y;
        x = y;
        y = rem;
    }
    return x < 0 ? -x : x;
}

/// Point-update / range-query segment tree over the index range [0, n).
///
/// All ranges are half-open: a node covering [l, r) splits at m = (l + r) / 2
/// into a left child covering [l, m) and a right child covering [m, r).
/// Nodes live in an implicit array: the root is node 0 and node x has
/// children 2x+1 and 2x+2.
///
/// Requirements on Info, as used by the code below:
///   - default-constructible; Info() is what a query returns for a part of
///     the tree that does not intersect the requested range,
///   - `a + b` merges the summaries of two adjacent ranges,
///   - assignable from the element type T handed to init(),
///   - has members `gcd` and `sum` that support `+`, because modify() adds
///     to both fields of the leaf (this ties the tree to that Info layout).
template<class Info>
struct SegmentTree {
    int n;                   // number of elements (leaves)
    std::vector<Info> info;  // node summaries, indexed as described above

    /// Creates an empty tree (n == 0).
    SegmentTree():n(0) {}

    /// Creates a tree of _n elements, each initialised to v.
    SegmentTree(int _n, Info v = Info()) {
        init(_n, v);
    }

    /// Re-initialises the tree with _n copies of v.
    void init(int _n, Info v = Info()) {
        init(std::vector<Info>(_n, v));
    }

    /// Creates a tree whose i-th leaf is built from v[i].
    template<class T>
    SegmentTree(std::vector<T> v) {
        init(std::move(v));
    }

    /// Re-initialises the tree from v; an empty v yields an empty tree.
    template<class T>
    void init(std::vector<T> v) {
        n = static_cast<int>(v.size());
        if (n == 0) {
            // Nothing to build; recursing on the empty range [0, 0) would
            // never reach a leaf.
            info.clear();
            return;
        }
        // Largest power of two not exceeding n; four times that many slots
        // is the same array size the tree has always used.
        int pow2 = 1;
        while (pow2 <= n / 2) {
            pow2 *= 2;
        }
        info.assign(static_cast<std::size_t>(pow2) * 4, Info());
        build(v, 0, 0, n);
    }

    /// Recomputes node x from its two children.
    void pushup(int x) {
        info[x] = info[2*x+1] + info[2*x+2];
    }

    /// Recursive worker for modify(p, v): node x covers [l, r).
    void modify(int x, int l, int r, int p, const Info &v) {
        if (l + 1 == r) {
            // Point update is an addition, not an assignment: both fields
            // of the leaf are increased by the matching field of v.
            info[x].gcd = info[x].gcd + v.gcd;
            info[x].sum = info[x].sum + v.sum;
            return;
        }
        int m = (l + r) / 2;
        if (p < m) {
            modify(2*x+1, l, m, p, v);
        } else {
            modify(2*x+2, m, r, p, v);
        }
        pushup(x);
    }

    /// Adds v.gcd / v.sum to the leaf at index p and refreshes its ancestors.
    /// An index outside [0, n) is ignored: the descent below has no bounds
    /// check of its own and would otherwise update the wrong leaf.
    void modify(int p, const Info &v) {
        if (p < 0 || p >= n) {
            return;
        }
        modify(0, 0, n, p, v);
    }

    /// Recursive worker for rangeQuery(L, R): node x covers [l, r).
    Info rangeQuery(int x, int l, int r, int L, int R) {
        if (R <= l || r <= L) {
            return Info();
        }
        if (L <= l && r <= R) {
            return info[x];
        }
        int m = (l + r) / 2;
        return rangeQuery(2*x+1, l, m, L, R) + rangeQuery(2*x+2, m, r, L, R);
    }

    /// Merged summary of the elements in [L, R); Info() if that is empty.
    Info rangeQuery(int L, int R) {
        if (n == 0) {
            return Info();
        }
        return rangeQuery(0, 0, n, L, R);
    }

    /// Recursive worker for findFirst(L, R, pred): node x covers [l, r).
    template<class F>
    int findFirst(int x, int l, int r, int L, int R, F pred) {
        // A subtree is skipped entirely when pred rejects its summary.
        if (R <= l || r <= L || !pred(info[x])) {
            return -1;
        }
        if (l + 1 == r) {
            return l;
        }
        int m = (l + r) / 2;
        int res = findFirst(2*x+1, l, m, L, R, pred);
        if (res == -1) {
            res = findFirst(2*x+2, m, r, L, R, pred);
        }
        return res;
    }

    /// Smallest index in [L, R) whose leaf satisfies pred, or -1 if none.
    /// pred is also evaluated on the summaries of inner nodes and a subtree
    /// is pruned when it fails there, so pred must hold for every ancestor
    /// of a leaf that should be found.
    template<class F>
    int findFirst(int L, int R, F pred) {
        if (n == 0) {
            return -1;
        }
        return findFirst(0, 0, n, L, R, pred);
    }

    /// Recursive worker for findLast(L, R, pred): node x covers [l, r).
    template<class F>
    int findLast(int x, int l, int r, int L, int R, F pred) {
        if (R <= l || r <= L || !pred(info[x])) {
            return -1;
        }
        if (l + 1 == r) {
            return l;
        }
        int m = (l + r) / 2;
        // Mirror image of findFirst: try the right half first.
        int res = findLast(2*x+2, m, r, L, R, pred);
        if (res == -1) {
            res = findLast(2*x+1, l, m, L, R, pred);
        }
        return res;
    }

    /// Largest index in [L, R) whose leaf satisfies pred, or -1 if none.
    /// The same pruning rule as findFirst applies.
    template<class F>
    int findLast(int L, int R, F pred) {
        if (n == 0) {
            return -1;
        }
        return findLast(0, 0, n, L, R, pred);
    }

private:
    /// Fills the subtree rooted at node x, covering [l, r), from v.
    template<class T>
    void build(const std::vector<T> &v, int x, int l, int r) {
        if (l + 1 == r) {
            // Relies on Info being assignable from T.
            info[x] = v[l];
            return;
        }
        int m = (l + r) / 2;
        build(v, 2*x+1, l, m);
        build(v, 2*x+2, m, r);
        pushup(x);
    }
};

/// Segment summary stored in each tree node: the gcd and the sum of the
/// values in the segment. A single value v is summarised as {gcd = v, sum = v}.
struct Info {
    i64 gcd;
    i64 sum;

    /// Summary of one value; also serves as the default (all-zero) summary.
    /// Intentionally implicit so that an i64 can be assigned to an Info.
    Info(i64 v = 0) : gcd(v), sum(v) {}

    /// Merges two summaries: gcds are combined with mygcd, sums are added.
    friend Info operator+(const Info &lhs, const Info &rhs) {
        Info merged(lhs.sum + rhs.sum);
        merged.gcd = mygcd(lhs.gcd, rhs.gcd);
        return merged;
    }
};

void solve() {
    int N, M;
    std::cin >> N >> M;
    std::vector<i64> A(N), B(N);
    for (int i = 0; i < N; i++) {
        std::cin >> A[i];
    }
    std::adjacent_difference(A.begin(), A.end(), B.begin());
    SegmentTree<Info> seg(B);
    for (int i = 0; i < M; i++) {
        std::string op;
        i64 l, r, d;
        std::cin >> op >> l >> r;
        l--, r--;
        if (op == "C") {
            std::cin >> d;
            seg.modify(l, Info(d));
            if (r + 1 < N) {
                seg.modify(r+1, Info(-d));
            }
        } else if (op == "Q") {
            i64 gcd1 = seg.rangeQuery(0, l+1).sum;
            i64 gcd2 = seg.rangeQuery(l+1, r+1).gcd;
            std::cout << std::abs(mygcd(gcd1, gcd2)) << "\n";
        }
    }
}

/// Entry point: unties cin from cout, turns off stdio synchronisation, then
/// runs solve() once.
int main() {
    std::cin.tie(nullptr);
    std::ios_base::sync_with_stdio(false);

    for (int cases = 1; cases > 0; --cases) {
        solve();
    }
    return 0;
}

相關文章