線段樹簡單思路

lulaalu發表於2024-05-09

線段樹

1、瞭解儲存結構

存放左右範圍,和左右子節點的預定義

#define lc p<<1
#define rc p<<1|1
struct node {
    int l, r, sum;
}tr[4 * N];//注意要開4倍,感興趣自己搜

2、遞迴建樹

思路

  1. 對結點的左右範圍賦值
  2. 判斷是不是葉子結點
  3. 建左右子樹
  4. 更新當前結點的sum值
void build(int p,int l,int r) {
    tr[p].l = l, tr[p].r = r;
    if (l == r)return;//表明是葉子結點,不需要再建立
    int m = l + r >> 1;
    build(lc, l, m);
    build(rc, m + 1, r);
    //建完樹後更新當前結點
    tr[p].sum = tr[lc].sum + tr[rc].sum;
}

3、區間查詢

思路

  1. 將區間分裂開,如果被覆蓋,就直接返回,如果沒有,就進行左右子樹的查詢
  2. 如果左子節點的區間與查詢區間有交集,就查詢,右邊同理
int query(int p,int x,int y) {
    //如果區間被包含,直接return
    if (x <= tr[p].l && tr[p].r <= y) 
            return tr[p].sum;
    int m = tr[p].l + tr[p].r >> 1;
    int sum = 0;
    //由m和xy進行對比,判斷是不是與子樹由交集
    if (x <= m) sum += query(lc, x, y);
    if (y > m) sum += query(rc, x, y);
    return sum;
}

4、區間修改

運用懶標記

  • 簡單解釋就是,如果修改的區間完全覆蓋當前的區間,比如對2~7區間每個值都加3

    那麼就是先將當前的p的sum進行修改,sum+=(r-l+1)*3,然後將當前結點打上懶標記add+=3

所以儲存結構要改為:

struct node {
    int l, r, sum, add;
}tr[4 * N]
  • 接上,如果不覆蓋,先將當前懶標記向下傳遞,即pushdown(),然後遞迴更新子區間(類似區間查詢的流程)

    修改完畢後更新當前節點,此時引入pushup函式

void pushup(int p) { tr[p].sum = tr[lc].sum + tr[rc].sum; }
void pushdown(int p) {
    if (!tr[p].add) return;
    int add = tr[p].add;
    tr[lc].sum += add * (tr[lc].r - tr[lc].l + 1);
    tr[rc].sum += add * (tr[rc].r - tr[rc].l + 1);
    tr[lc].add += add;
    tr[rc].add += add;
    tr[p].add = 0;
}
void update(int p,int x,int y,int k) {
    //如果完全覆蓋就直接更新
    if (x <= tr[p].l && tr[p].r <= y) {
        tr[p].sum += k * (tr[p].r - tr[p].l + 1);
        tr[p].add += k;
        return;
    }
    //如果不是完全覆蓋
    //先將add向下傳
    pushdown(p);
    int m = tr[p].l + tr[p].r >> 1;
    //再根據左右子樹的區間重疊情況更新
    if (x <= m)update(lc, x, y, k);
    if (y > m)update(rc, x, y, k);
    //最後更新一下當前節點
    pushup(p);
}

5、完整程式碼

針對例題https://www.luogu.com.cn/problem/P3372

以下是完整程式碼,與上面略有差別

#include<iostream>
using namespace std;
typedef long long ll;
#define endl "\n"
#define IOS ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
const int N = 1e5 + 10;

#define lc p<<1
#define rc p<<1|1
struct node {
    int l, r;
    ll sum, add;
}tr[4 * N];
int arr[N];
void pushup(int p) {tr[p].sum = tr[lc].sum + tr[rc].sum; }
void pushdown(int p) {
    if (!tr[p].add) return;
    ll add = tr[p].add;
    tr[lc].sum += add * (tr[lc].r - tr[lc].l + 1);
    tr[rc].sum += add * (tr[rc].r - tr[rc].l + 1);
    tr[lc].add += add;
    tr[rc].add += add;
    tr[p].add = 0;
}


void build(int p,int l,int r) {
    tr[p] = { l,r,arr[l],0 };
    if (l == r)return;
    int m = l + r >> 1;
    build(lc, l, m);
    build(rc, m + 1, r);
    //建完樹後更新當前結點
    pushup(p);//tr[p].sum = tr[lc].sum + tr[rc].sum;
}

ll query(int p,int x,int y) {
    if (x <= tr[p].l && tr[p].r <= y) 
            return tr[p].sum;

    pushdown(p);//此處先更新一下再查詢
    int m = tr[p].l + tr[p].r >> 1;
    ll sum = 0;
    if (x <= m) sum += query(lc, x, y);
    if (y > m) sum += query(rc, x, y);
    return sum;
}

void update(int p,int x,int y,int k) {
    if (x <= tr[p].l && tr[p].r <= y) {
        tr[p].sum += k * (tr[p].r - tr[p].l + 1);
        tr[p].add += k;
        return;
    }
    pushdown(p);
    int m = tr[p].l + tr[p].r >> 1;
    if (x <= m)update(lc, x, y, k);
    if (y > m)update(rc, x, y, k);
    pushup(p);
}

int n,m,choice,x,y,k;
int main() {
    IOS;
    cin >> n >> m;
    for (int i = 1; i <= n; i++) 
        cin >> arr[i];
    
    build(1, 1, n);

    while (m--) {
        cin >> choice;
        if (choice == 1) {
            cin >> x >> y >> k;
            update(1, x, y, k);
        }
        else {
            cin >> x >> y;
            cout << query(1, x, y) << endl;
        }
    }
    return 0;
}

相關文章