【學習筆記】Segment Tree Beats/吉司機線段樹

liukejie發表於2024-11-16

連結

區間最值操作

HDU-5306

支援對區間取 \(\min\),維護區間 \(\max\),查詢區間和。

很容易想到一個暴力,我們每一次找出這個區間的最大值 \(mx\),如果 \(mx>x\),那麼暴力修改這個位置的值,否則已經修改完畢,退出,時間複雜度為 \(O(n^2 \log n)\)

打一打補丁,對線段樹上的每一個區間維護區間最大值 \(mx\),這個區間中最大值出現的次數 \(t\),區間次大值 \(se\),當然還要維護區間和 \(sum\)

現在考慮打上區間取 \(\min\) 標記

  • 如果 \(mx\le x\),那麼對 \(sum\) 就沒有修改。
  • 如果 \(se<x<mx\),那麼 \(sum=sum-(mx-x)\times t\)
  • 如果 \(x\le se<mx\),此時無法直接更新節點資訊,故向下左右子樹遞迴。我們分別 DFS 這個節點的兩個孩子,如果當前 DFS 的過程中遇到了前兩種情況,就直接修改打上標記然後退出,否則就繼續 DFS。
點選檢視程式碼
#include<bits/stdc++.h>
using namespace std;

#define ls p<<1
#define rs p<<1|1
#define ll long long
inline char gc()
{
    static char buf[1 << 20/*這裡很玄學,改成其他數字可能更快*/], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 20/*改成和上面一樣的數字*/, stdin), p1 == p2) ? EOF : *p1 ++;
}

inline void read(int &n) // 用法 read(n);
{
    bool w = 0;
    char c = gc();
    for(; c < 48 || c > 57; c = gc())
        w = c == 45;
    for(n = 0; c >= 48 && c <= 57; c = gc())
        n = n * 10 + c - 48;
    n = w ? -n : n;
}
const int N=1e6+7;

int T,n,m,a[N],mx[N<<2],se[N<<2],cnt[N<<2],tag[N<<2];
ll sum[N<<2];

inline void pushup(int p){
    sum[p]=sum[ls]+sum[rs];
    if(mx[ls]==mx[rs]){
        mx[p]=mx[ls],se[p]=max(se[ls],se[rs]);
        cnt[p]=cnt[ls]+cnt[rs];
    }
    else if(mx[ls]>mx[rs]){
        mx[p]=mx[ls],se[p]=max(se[ls],mx[rs]);
        cnt[p]=cnt[ls];
    }
    else{
        mx[p]=mx[rs],se[p]=max(se[rs],mx[ls]);
        cnt[p]=cnt[rs];
    }
    return;
}

inline void build(int p,int l,int r){
    tag[p]=-1;
    if(l==r){
        sum[p]=mx[p]=a[l];
        cnt[p]=1,se[p]=-1;
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid),build(rs,mid+1,r);
    pushup(p);
    return;
}

inline void pushtag(int p,int tg){
    if(mx[p]<=tg)return;
    sum[p]+=(ll)(tg-mx[p])*cnt[p];
    mx[p]=tag[p]=tg;
    return;
}

inline void pushdown(int p){
    if(tag[p]==-1)return;
    pushtag(ls,tag[p]),pushtag(rs,tag[p]);
    tag[p]=-1;
    return;
}

inline void update(int p,int l,int r,int s,int t,int val){
    if(mx[p]<=val)return;
    if(s<=l&&r<=t&&se[p]<val){
        pushtag(p,val);
        return;
    }
    int mid=(l+r)>>1;
    pushdown(p);
    if(s<=mid)update(ls,l,mid,s,t,val);
    if(t>mid)update(rs,mid+1,r,s,t,val);
    pushup(p);
    return;
}

inline int querymax(int p,int l,int r,int s,int t){
    if(s<=l&&r<=t)return mx[p];
    int mid=(l+r)>>1,res=-1;
    pushdown(p);
    if(s<=mid)res=max(res,querymax(ls,l,mid,s,t));
    if(t>mid)res=max(res,querymax(rs,mid+1,r,s,t));
    return res;
}

inline ll querysum(int p,int l,int r,int s,int t){
    if(s<=l&&r<=t)return sum[p];
    int mid=(l+r)>>1;ll res=0;
    pushdown(p);
    if(s<=mid)res+=querysum(ls,l,mid,s,t);
    if(t>mid)res+=querysum(rs,mid+1,r,s,t);
    return res;
}

inline void solve(){
    read(n); read(m);
    for(int i=1;i<=n;i++)read(a[i]);
    build(1,1,n);
    for(int i=1;i<=m;i++){
        int op,l,r,val;
        read(op); read(l); read(r);
        if(!op){
            read(val);
            update(1,1,n,l,r,val);
        }
        else if(op==1)printf("%d\n",querymax(1,1,n,l,r));
        else printf("%lld\n",querysum(1,1,n,l,r));
    }
    return;
}

int main(){
    scanf("%d",&T);
    while(T--)solve();
    return 0;
}

相關文章