線段樹入門(Segment Tree)
基本線段樹
與樹狀陣列(Fenwick tree)功能類似,可實現單點修改與區間查詢。
首先實現基本的線段樹的構建:
#include <iostream>
#include <vector>
using namespace std;
class segmentTree{
public:
    // Prepares storage for a sum segment tree over n values.
    // 4*n slots are allocated for the 1-indexed heap layout; every slot stays 0
    // until build() is called.
    segmentTree(int n, vector<int> nums) : size(4 * n), nums(nums), tree(size) {}

    // Recursively stores in tree[i] the sum of nums[l..r].
    // Node i's children live at 2*i and 2*i+1; call build(0, n-1, 1) for the root.
    void build(int l, int r, int i){
        if (l != r) {
            const int mid = l + (r - l) / 2;
            const int leftChild = 2 * i;
            const int rightChild = leftChild + 1;
            build(l, mid, leftChild);
            build(mid + 1, r, rightChild);
            tree[i] = tree[leftChild] + tree[rightChild];
        } else {
            tree[i] = nums[l];  // leaf: copy the input value
        }
    }

    // Returns a copy of the backing array (slot 0 is untouched when the root is built at i = 1).
    vector<int> getTree(){
        return tree;
    }
private:
    int size;          // number of slots allocated for the tree (4 * n)
    vector<int> nums;  // copy of the input values
    vector<int> tree;  // tree[i] is the sum over node i's interval
};
int main(){
int n = 4;
vector<int> nums{1,2,3,4};
auto st = segmentTree(n, nums);
st.build(0, n-1, 1);
vector<int> tree = st.getTree();
for (int i=0; i<tree.size(); i++) {
cout<< tree[i]<<" ";
}
return 0;
}
單點修改:
單點修改有兩種:一是增量式修改,即 nums[i] += x(記為 add 方法);二是覆蓋式修改,即 nums[i] = x(記為 update 方法)。
#include <iostream>
#include <stdexcept>
#include <vector>
using namespace std;
class segmentTree{
public:
    // Builds a sum segment tree over nums[0..n-1].
    // @param n     number of elements; expected to be <= nums.size().
    // @param nums  initial values (a private copy is kept for build()).
    segmentTree(int n, std::vector<int> nums){
        this->n = n;
        size = 4 * n;
        tree.resize(size);
        this->nums = nums;
        // An empty array has nothing to build; build(0, -1, 1) would write past the empty tree.
        if (n > 0) build(0, n - 1, 1);
    }
    // Incremental point update: nums[index] += x.
    // @throws std::out_of_range if index is not in [0, n).
    void add(int x, int index){
        checkIndex(index);
        add_p(x, index, 0, n - 1, 1);
    }
    // Overwriting point update: nums[index] = x.
    // @throws std::out_of_range if index is not in [0, n).
    void update(int x, int index) {
        checkIndex(index);
        update_p(x, index, 0, n - 1, 1);
    }
    // Returns a copy of the backing array (the root is at index 1).
    std::vector<int> getTree(){
        return tree;
    }
    // Prints the backing array on one line, space separated.
    void printTree(){
        for (int value : tree) {
            std::cout << value << " ";
        }
        std::cout << std::endl;
    }
private:
    int n;                  // number of leaves
    int size;               // number of allocated tree slots (4 * n)
    std::vector<int> nums;  // copy of the initial values (only read by build())
    std::vector<int> tree;  // tree[i] is the sum over node i's interval

    // Rejects indices outside [0, n). Without this check the descent would silently
    // modify the first leaf (negative index) or the last leaf (index >= n).
    void checkIndex(int index) const {
        if (index < 0 || index >= n) {
            throw std::out_of_range("segmentTree: index out of range");
        }
    }
    // Recursively fills tree[i] with the sum of nums[l..r]; i is 1-indexed.
    void build(int l, int r, int i){
        if (l == r) {
            tree[i] = nums[l];
            return;
        }
        int m = l + (r - l) / 2;
        build(l, m, i * 2);
        build(m + 1, r, 2 * i + 1);
        tree[i] = tree[i * 2] + tree[i * 2 + 1];
    }
    // Adds x to the leaf for `index` inside node i (covering [l, r]), then recomputes the sums on the path.
    void add_p(int x, int index, int l, int r, int i) {
        if (l == r) {
            tree[i] += x;
            return;
        }
        int m = l + (r - l) / 2;
        if (index <= m)
            add_p(x, index, l, m, i * 2);
        else
            add_p(x, index, m + 1, r, i * 2 + 1);
        tree[i] = tree[i * 2] + tree[i * 2 + 1];
    }
    // Sets the leaf for `index` inside node i (covering [l, r]) to x, then recomputes the sums on the path.
    void update_p(int x, int index, int l, int r, int i) {
        if (l == r) {
            tree[i] = x;
            return;
        }
        int m = l + (r - l) / 2;
        // Must recurse into update_p (not add_p) so the leaf is overwritten rather than incremented.
        if (index <= m)
            update_p(x, index, l, m, i * 2);
        else
            update_p(x, index, m + 1, r, i * 2 + 1);
        tree[i] = tree[i * 2] + tree[i * 2 + 1];
    }
};
int main(){
int n = 4;
vector<int> nums{1,2,3,4};
auto st = segmentTree(n, nums);
st.printTree();
st.add(5, 0);
st.printTree();
st.update(100, 1);
st.printTree();
return 0;
}
離散化方法:
先複製一份陣列並排序,然後用二分查找定位每個元素在排序後陣列中的位置,以此作為其編號(從 1 開始;相同的值會得到相同的編號)。注意:distance 計算的是兩個迭代器之間相隔多少個元素;lower_bound 返回的是第一個大於等於指定元素的迭代器位置。
// Replaces every value in nums by its 1-based rank in the sorted order.
// Equal values receive the same rank: the position of their first occurrence after sorting.
void discrete(std::vector<int>& nums) {
    std::vector<int> sorted(nums.begin(), nums.end());
    std::sort(sorted.begin(), sorted.end());
    for (int& value : nums) {
        // First slot in `sorted` holding a value >= this one.
        const auto firstNotLess = std::lower_bound(sorted.begin(), sorted.end(), value);
        value = static_cast<int>(std::distance(sorted.begin(), firstNotLess) + 1);
    }
}
動態線段樹(動態開點):LeetCode 699. Falling Squares
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
class Solution {
public:
    // Right end of the coordinate range covered by the tree; the root spans [0, N].
    int N = static_cast<int>(1e9);

    // Node of a dynamic segment tree whose children are created on demand.
    // It supports range-max queries and range assignment.
    struct Node{
        Node* ls;   // left child, created by pushdown() when first needed
        Node* rs;   // right child, created by pushdown() when first needed
        int val;    // maximum height over this node's interval
        int lazy;   // pending assignment for the whole interval; 0 means "none" (assumes assigned heights are positive)
        Node(){val=0;lazy=0;ls=nullptr; rs=nullptr;}
    };

    // Returns the maximum height on [index1, index2]; cur is the node covering [l, r].
    int query_p(int index1, int index2, int l, int r, Node* cur){
        if (index1 <= l && r <= index2) {
            return cur->val;
        }
        pushdown(cur);  // create children if needed and forward any pending assignment
        int m = l + (r - l) / 2;
        int left = 0, right = 0;
        if (index1 <= m) {
            left = query_p(index1, index2, l, m, cur->ls);
        }
        if (index2 > m) {
            right = query_p(index1, index2, m + 1, r, cur->rs);
        }
        return std::max(left, right);
    }

    // Assigns height h to every point of [index1, index2]; node covers [l, r].
    void update(Node* node, int index1, int index2, int l, int r, int h){
        if (index1 <= l && r <= index2) {
            node->lazy = h;
            node->val = h;
            return;
        }
        pushdown(node);
        int m = l + (r - l) / 2;
        if (index1 <= m) update(node->ls, index1, index2, l, m, h);
        if (index2 > m) update(node->rs, index1, index2, m + 1, r, h);
        pushup(node);
    }

    // Creates missing children, then hands any pending assignment down to them.
    void pushdown(Node* cur){
        if (!cur->ls) cur->ls = new Node();
        if (!cur->rs) cur->rs = new Node();
        if (cur->lazy == 0) {
            return;
        }
        cur->ls->lazy = cur->lazy;
        cur->rs->lazy = cur->lazy;
        cur->ls->val = cur->lazy;
        cur->rs->val = cur->lazy;
        cur->lazy = 0;
    }

    // Recomputes node's maximum from its two children.
    void pushup(Node* node) {
        node->val = std::max(node->ls->val, node->rs->val);
    }

    // For each square [left, side], drops it onto [left, left+side-1].
    // Returns the tallest stack height after each drop.
    std::vector<int> fallingSquares(std::vector<std::vector<int>>& positions) {
        std::vector<int> ans;
        ans.reserve(positions.size());
        Node* root = new Node();
        for (const auto& pos : positions) {  // const ref: avoid copying each inner vector
            const int x = pos[0];
            const int h = pos[1];
            const int right = x + h - 1;  // the square covers integer points [x, x+h-1]
            const int cur = query_p(x, right, 0, N, root);
            // The new top is cur + h, which is >= every height in the range, so assignment is correct.
            update(root, x, right, 0, N, cur + h);
            ans.push_back(root->val);
        }
        release(root);  // free every node allocated for this call
        return ans;
    }

private:
    // Recursively deletes a subtree created with new (root and pushdown-created children).
    void release(Node* node) {
        if (!node) return;
        release(node->ls);
        release(node->rs);
        delete node;
    }
};