abc343F 區間第2大的出現次數

chenfy27發表於2024-03-17

給定陣列a[n],有Q組操作,格式為:

  • 1 p x,將a[p]修改為x;
  • 2 l r,查詢區間[l,r]內第2大元素(即嚴格小於區間最大值的最大值)的出現次數;若區間內不存在第2大元素,則輸出0。

1<=n,q<=2e5; 1<=a[i]<=1e9

用線段樹維護各個區間的最大值與次大值及其出現次數;合併左右子區間時,把相同值的出現次數相加,最多只保留值最大的兩條記錄。

#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand: every later `int` token expands to
// `long long`, which is why the entry point below is declared `signed main`.
#define int long long
// rep: loop i upward from a to b, both ends inclusive.
#define rep(i,a,b) for(int i=a;i<=b;i++)
// per: loop i downward from b to a, both ends inclusive.
#define per(i,a,b) for(int i=b;i>=a;i--)

// N: sizing bound for the tree arrays, which are allocated with 4*N slots.
const int N = 200005;
// n: array length, Q: number of operations (both read in solve()).
// d[x]: the element value held by leaf node x (written in build()/update()).
int n, Q, d[4*N];
// cnt[x]: value -> occurrence count for the segment covered by tree node x.
// pushup() prunes it so that only the two largest distinct values remain.
map<int,int> cnt[4*N];

// Rebuild the summary of node x from its children 2*x+1 and 2*x+2.
// Counts of equal values are added together, then the smallest keys are
// discarded until at most the two largest distinct values are left.
void pushup(int x) {
    auto& summary = cnt[x];
    summary.clear();
    for (int child = 2*x+1; child <= 2*x+2; ++child) {
        for (const auto& entry : cnt[child]) {
            summary[entry.first] += entry.second;
        }
    }
    // The map is ordered ascending, so begin() is always the smallest key.
    while (summary.size() > 2) {
        summary.erase(summary.begin());
    }
}
// Build the tree for node x covering the half-open index range [l, r).
// Leaves (ranges of length one) read their value from standard input in
// left-to-right order; inner nodes are derived from their children.
void build(int x, int l, int r) {
    if (r != l + 1) {
        const int mid = (l + r) / 2;
        build(2*x+1, l, mid);   // left half  [l, mid)
        build(2*x+2, mid, r);   // right half [mid, r)
        pushup(x);
        return;
    }
    // Leaf: a single element that occurs exactly once.
    std::cin >> d[x];
    cnt[x][d[x]] = 1;
}
// Point assignment: set the element at index u to v.
// x is the current node and [l, r) the half-open range it covers.
// Every node on the path back to the root is recomputed via pushup().
void update(int x, int l, int r, int u, int v) {
    if (r != l + 1) {
        const int mid = (l + r) / 2;
        if (u >= mid) {
            update(2*x+2, mid, r, u, v);
        } else {
            update(2*x+1, l, mid, u, v);
        }
        pushup(x);
        return;
    }
    // Leaf: replace the stored value and reset its summary to {v: 1}.
    d[x] = v;
    cnt[x].clear();
    cnt[x][v] = 1;
}
map<int,int> query(int x, int l, int r, int L, int R) {
    map<int,int> ret;
    if (R <= l || r <= L) return ret;
    if (L <= l && r <= R) return cnt[x];
    int m = (l+r) / 2;
    map<int,int> z1 = query(2*x+1, l, m, L, R);
    map<int,int> z2 = query(2*x+2, m, r, L, R);
    for (auto [k,v] : z1) ret[k] += v;
    for (auto [k,v] : z2) ret[k] += v;
    while (ret.size() > 2) {
        ret.erase(ret.begin());
    }
    return ret;
}
void solve() {
    cin >> n >> Q;
    build(0, 1, n+1);
    rep(i,1,Q) {
        int op, x, y;
        cin >> op >> x >> y;
        if (op == 1) {
            update(0, 1, n+1, x, y);
        } else {
            auto t = query(0, 1, n+1, x, y+1);
            if (t.size() < 2) {
                cout << 0 << "\n";
            } else {
                auto it = t.begin();
                cout << it->second << "\n";
            }
        }
    }
}

// Entry point: switch to fast iostream mode, then run the single test case.
// Declared `signed` because the file redefines the `int` token.
signed main() {
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
    return 0;
}

相關文章