線段樹入門(Segment Tree)

石中火本火發表於2024-06-02

線段樹入門(Segment Tree)

基本線段樹

與樹狀陣列（Fenwick Tree）功能類似，線段樹支援單點修改與區間查詢：

首先實現基本的線段樹的構建:

#include <iostream>
#include <vector>
using namespace std;

// Array-backed segment tree holding range sums (first version: construction only).
// Slot 1 is the root and the children of slot i are slots 2*i and 2*i+1, so
// slot 0 is never written. 4*n slots are reserved (the usual safe bound for
// this layout).
class segmentTree{
public:
    // Keeps a copy of the input and allocates 4*n zero-filled slots.
    // The tree is not populated until build(0, n-1, 1) is called.
    segmentTree(int n, vector<int> nums) : size(4 * n), nums(nums), tree(4 * n) {}

    // Recursively stores in slot i the sum of nums[l..r] (inclusive, 0-based).
    void build(int l, int r, int i){ // i indexed from 1.
        if (l != r) {
            const int mid = l + (r - l) / 2;
            const int leftChild = i << 1;
            const int rightChild = leftChild | 1;
            build(l, mid, leftChild);
            build(mid + 1, r, rightChild);
            tree[i] = tree[leftChild] + tree[rightChild];
            return;
        }
        tree[i] = nums[l];  // leaf: the element itself
    }

    // Returns a copy of the whole slot array (including unused slots).
    vector<int> getTree(){
        vector<int> snapshot = tree;
        return snapshot;
    }
private:
    int size;          // number of allocated slots (4*n)
    vector<int> nums;  // copy of the input array
    vector<int> tree;  // slot array; tree[1] is the root
};

int main(){
    
    int n = 4;
    vector<int> nums{1,2,3,4};
    auto st = segmentTree(n, nums);
    st.build(0, n-1, 1);
    vector<int> tree = st.getTree();
    for (int i=0; i<tree.size(); i++) {
        cout<< tree[i]<<" ";
    }
    return 0;
}

單點修改:

單點修改有兩種：一是增量式修改，即在原值上加上某值（記為 add 方法，相當於 nums[i] += x）；二是覆蓋式修改，即把原值改為某值（記為 update 方法，相當於 nums[i] = x）。

#include <iostream>
#include <vector>
using namespace std;

// Array-backed segment tree of range sums supporting point modification.
// Slot 1 is the root; the children of slot i are 2*i and 2*i+1.
class segmentTree{
public:
    // Builds the tree over nums immediately. n is expected to equal nums.size().
    segmentTree(int n, vector<int> nums){
        this->n = n;
        size = 4 * n;
        tree.resize(size);
        this->nums = nums;
        // An empty array has no root: build(0, -1, 1) would write past the
        // (empty) tree, so only build when there is at least one element.
        if (n > 0)
            build(0, n-1, 1);
    }

    // Point increment: nums[index] += x. Out-of-range indices are ignored
    // (previously they silently modified the nearest boundary element).
    void add(int x, int index){ // 對陣列中的index加x
        if (index < 0 || index >= n)
            return;
        add_p(x, index, 0, n-1, 1);
    }
    // Point assignment: nums[index] = x. Out-of-range indices are ignored.
    void update(int x, int index) {
        if (index < 0 || index >= n)
            return;
        update_p(x, index, 0, n-1, 1);
    }
    // Returns a copy of the whole slot array (slot 0 is unused).
    vector<int> getTree(){
        return tree;
    }
    // Prints every slot separated by spaces, followed by a newline.
    void printTree(){
        for (const int v : tree) {
            cout << v << " ";
        }
        cout << endl;
    }
private:
    int n;             // number of elements
    int size;          // number of allocated slots (4*n)
    vector<int> nums;  // initial values; only read by build()
    vector<int> tree;  // slot array; tree[1] is the root
    // Recursively stores in slot i the sum of nums[l..r] (inclusive, 0-based).
    void build(int l, int r, int i){ // i indexed from 1.
        if (l == r) {
            tree[i] = nums[l];
            return;
        }
        int m = l + (r - l) / 2;
        build(l, m, i * 2);
        build(m+1, r, 2 * i + 1);
        tree[i] = tree[i * 2] + tree[i * 2 + 1];
    }
    // Adds x to the leaf for `index` under slot i (covering [l, r]) and
    // recomputes the sums on the way back up.
    void add_p(int x, int index, int l, int r, int i) {
        if (l == r) {
            tree[i] += x;
            return;
        }
        int m = l + (r-l) / 2;
        if (index <= m)
            add_p(x, index, l, m, i*2);
        else
            add_p(x, index, m+1, r, i*2+1);
        tree[i] = tree[i * 2] + tree[i * 2 + 1];
    }
    // Overwrites the leaf for `index` under slot i (covering [l, r]) with x
    // and recomputes the sums on the way back up.
    void update_p(int x, int index, int l, int r, int i) {
        if (l == r) {
            tree[i] = x;
            return;
        }
        int m = l + (r-l) / 2;
        // Recurse with update_p (not add_p): assignment must reach the leaf.
        if (index <= m)
            update_p(x, index, l, m, i*2);
        else
            update_p(x, index, m+1, r, i*2+1);
        tree[i] = tree[i * 2] + tree[i * 2 + 1];
    }
};

int main(){
    
    int n = 4;
    vector<int> nums{1,2,3,4};
    auto st = segmentTree(n, nums);
    st.printTree();
    st.add(5, 0);
    st.printTree();
    st.update(100, 1);
    st.printTree();
    return 0;
}

離散化方法:

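先複製一份陣列並排序，再透過二分查詢定位每個元素在排序後陣列中的位置，加 1 後作為從 1 開始的排名。注意：distance 計算的是兩個迭代器之間相隔多少個元素；lower_bound 返回的是第一個大於或等於指定元素的迭代器位置。由於複製的陣列沒有去重，相同的值會得到相同的排名，但排名之間可能出現間隔（例如 [5, 5, 7] 會變為 [1, 1, 3]）。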

// Coordinate compression in place: each value is replaced by 1 + the index of
// its first occurrence in a sorted copy of the input. Equal values share a
// rank; because the copy is not deduplicated, ranks can skip numbers.
void discrete(std::vector<int>& nums) {
    std::vector<int> sortedVals(nums.begin(), nums.end());
    std::sort(sortedVals.begin(), sortedVals.end());
    for (int& value : nums) {
        // First position whose element is not less than `value`.
        const auto firstNotLess =
            std::lower_bound(sortedVals.begin(), sortedVals.end(), value);
        value = static_cast<int>(firstNotLess - sortedVals.begin()) + 1;
    }
}

動態線段樹（LeetCode 699. 掉落的方塊 Falling Squares）

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

// LeetCode 699 (Falling Squares) solved with a dynamically allocated segment
// tree over the coordinate range [0, N]. Each node stores the maximum height in
// its range. A lazy tag records a pending "assign this height to the whole
// range" operation that has not yet been pushed to the children.
class Solution {
public:
    int N = (int)1e9;  // right end of the coordinate range [0, N]
    // Node of the dynamic segment tree; children are created on demand by pushdown().
    struct Node{
        Node* ls;   // left child, covers [l, m]
        Node* rs;   // right child, covers [m+1, r]
        int val;    // maximum height over this node's range
        // Pending assignment for the children; 0 means "none". This sentinel
        // assumes every assigned height is positive (side lengths >= 1 in
        // LeetCode 699) -- confirm before reusing with other inputs.
        int lazy; // 懶標記
        Node(){val=0;lazy=0;ls=nullptr; rs=nullptr;}
        // A node owns its children, so destroying a root frees its whole subtree.
        ~Node(){ delete ls; delete rs; }
        // Copies would share (and later double-delete) the same children.
        Node(const Node&) = delete;
        Node& operator=(const Node&) = delete;
    };
    // Returns the maximum height on [index1, index2], searching node cur, which
    // covers [l, r]. Callers only descend into children that overlap the query.
    int query_p(int index1, int index2, int l, int r, Node* cur){ // 查詢區間上的最高值
        if (index1 <= l && r <= index2) {
            return cur->val;
        }
        pushdown(cur);
        int m = l + (r - l) / 2;
        int left=0, right=0;
        if (index1 <= m) {
            left = query_p(index1, index2, l, m, cur->ls);
        }
        if (index2 > m) {
            right = query_p(index1, index2, m+1, r, cur->rs);
        }
        return max(left, right);
    }
    // Assigns height h to every point of [index1, index2] below node `node`,
    // which covers [l, r]. Fully covered nodes are tagged lazily instead of
    // being descended into.
    void update(Node* node, int index1, int index2, int l, int r, int h){
        if (index1 <= l && r <= index2) {
            node->lazy = h;
            node->val = h;
            return;
        }
        pushdown(node);
        int m = l + (r - l) / 2;
        if (index1 <= m) update(node->ls, index1, index2, l, m, h);
        if (index2 > m) update(node->rs, index1, index2, m+1, r, h);
        pushup(node);
    }
    // Ensures both children of cur exist, then forwards any pending assignment to them.
    void pushdown(Node* cur){
        if (!cur->ls) cur->ls = new Node();
        if (!cur->rs) cur->rs = new Node();
        if (cur->lazy == 0) {
            return;
        }
        cur->ls->lazy = cur->lazy;
        cur->rs->lazy = cur->lazy;
        cur->ls->val  = cur->lazy;
        cur->rs->val = cur->lazy;
        cur->lazy = 0;
    }
    // Recomputes node->val from its children. Both children exist here,
    // because pushdown() runs before every call.
    void pushup(Node* node) {
        node->val = max(node->ls->val, node->rs->val);
    }
    // For each square [left, sideLength] dropped in order, records the height
    // of the tallest stack after it lands.
    vector<int> fallingSquares(vector<vector<int>>& positions) {
        vector<int> ans;
        ans.reserve(positions.size());
        // Stack-owned root: its destructor releases every node allocated by
        // pushdown() when this call returns (or unwinds).
        Node root;
        for (const auto& pos : positions) {
            int x = pos[0], h = pos[1];
            // The square rests on the tallest point under its base [x, x+h-1] ...
            int cur = query_p(x, x+h-1, 0, N, &root);
            // ... and raises that whole span to cur + h.
            update(&root, x, x+h-1, 0, N, cur+h);
            ans.push_back(root.val);
        }
        return ans;
    }
};

相關文章