原題連結:https://www.luogu.com.cn/problem/P1253
題意解讀:對於一個序列a[n],支援三種操作:1.將區間[l,r]所有數設定為x;2.將區間[l,r]所有數加上x;3.查詢區間[l,r]的最大值
解題思路:典型的線段樹求解區間問題。
線段樹節點需要維護如下關鍵資訊:
1、區間l,r
2、區間最大值v
3、懶標記set,表示將所有子節點對應區間的每個數都設定為set
4、懶標記add,表示將所有子節點對應區間的每個數都加上add
兩個懶標記的關係是,先考慮set,再考慮add,否則結果不對。
接下來,就要解決這些節點資訊如何更新的問題:
1、給節點設定懶標記
對於set x操作,需要將當前節點的最大值v = x, add = 0, set恢復預設值(注意0不是清空set,set要設定成一個不可能取到的值,如INT_MAX)
對於add x操作,需要將當前節點的最大值v += x, add += x, set不變
void addtag(int u, int op, LL x)
{
if(op == 1) //set
{
tr[u].v = x;
tr[u].add = 0;
tr[u].set = x;
}
else //add
{
tr[u].v += x;
tr[u].add += x;
}
}
2、將節點懶標記下傳到子節點
這裡的關鍵在於下傳懶標記的順序,主要要先下傳set標記,再下傳add標記,因為如果先下傳add,後下傳set會將之前增加的值覆蓋掉。
void pushdown(int u)
{
if(tr[u].set != INF)
{
addtag(u << 1, 1, tr[u].set);
addtag(u << 1 | 1, 1, tr[u].set);
tr[u].set = INF;
}
if(tr[u].add)
{
addtag(u << 1, 2, tr[u].add);
addtag(u << 1 | 1, 2, tr[u].add);
tr[u].add = 0;
}
}
剩下的就是注意開long long!
100分程式碼:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1000005, INF = INT_MAX; //INF是懶標記set的預設值
struct Node
{
int l, r;
LL v; //區間[l,r]的最大值
LL set; //懶標記,將所有子節點都設定為set
LL add; //懶標記,將所有子節點都增加add
} tr[N * 4];
LL a[N];
int n, m;
void pushup(int u)
{
tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}
void build(int u, int l, int r)
{
tr[u] = {l, r, 0, INF, 0};
if(l == r) tr[u].v = a[l];
else
{
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void addtag(int u, int op, LL x)
{
if(op == 1) //set
{
tr[u].v = x;
tr[u].add = 0;
tr[u].set = x;
}
else //add
{
tr[u].v += x;
tr[u].add += x;
}
}
void pushdown(int u)
{
if(tr[u].set != INF)
{
addtag(u << 1, 1, tr[u].set);
addtag(u << 1 | 1, 1, tr[u].set);
tr[u].set = INF;
}
if(tr[u].add)
{
addtag(u << 1, 2, tr[u].add);
addtag(u << 1 | 1, 2, tr[u].add);
tr[u].add = 0;
}
}
LL query(int u, int l, int r)
{
if(tr[u].l >= l && tr[u].r <= r) return tr[u].v;
else if(tr[u].l > r || tr[u].r < l) return LLONG_MIN;
else
{
pushdown(u);
return max(query(u << 1, l, r), query(u << 1 | 1, l, r));
}
}
void update(int u, int l, int r, int op, LL x)
{
if(tr[u].l >= l && tr[u].r <= r) addtag(u, op, x);
else if(tr[u].l > r || tr[u].r < l) return;
else
{
pushdown(u);
update(u << 1, l, r, op, x);
update(u << 1 | 1, l, r, op, x);
pushup(u);
}
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) scanf("%lld", &a[i]);
build(1, 1, n);
int op, l, r;
LL x;
while(m--)
{
scanf("%d%d%d", &op, &l, &r);
if(op == 1 || op == 2)
{
scanf("%lld", &x);
update(1, l, r, op, x);
}
else printf("%lld\n", query(1, l, r));
}
return 0;
}