原題連結:https://www.luogu.com.cn/problem/P1471
題意解讀:給定序列a[n],支援三種操作:1.將區間每個數加上一個數 2.查詢區間的平均數 3、查詢區間的方差
解題思路:要支援區間修改和查詢,首選線段樹,下面看線段樹節點需要維護的資訊
平均數 = 區間和 / n,所以第一個要維護的資訊是區間和
再將計算方差的公式展開:
因此,除了維護區間和,還需要維護區間平方和
由於要進行區間修改,需要用到懶標記,定義節點為:
struct Node
{
int l, r;
double sum1; //區間[l,r]的和
double sum2; //區間[l,r]的平方和
double add; //懶標記,表示將所有子節點對應區間每一個數增加add
} tr[N * 4];
由區間和以及區間平方和的定義可知,pushup操作為:
void pushup(Node &root, Node &left, Node &right)
{
root.sum1 = left.sum1 + right.sum1;
root.sum2 = left.sum2 + right.sum2;
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
當執行區間修改操作時,如給區間[l,r]每個數增加k
則有sum1 = sum1 + k * (r-l+1)
而每個數增加k,平方和公式可以展開為:
則有sum2 = sum2 + 2 * k * sum1 + (r-l+1) * k * k
注意,sum2依賴的sum1是沒有加k之前的,所以要先計算sum2,再計算sum1
所以,新增懶標記以及pushdown操作為:
void addtag(int u, double k)
{
tr[u].sum2 += 2 * k * tr[u].sum1 + k * k * (tr[u].r - tr[u].l + 1);
tr[u].sum1 += k * (tr[u].r - tr[u].l + 1);
tr[u].add += k;
}
void pushdown(int u)
{
addtag(u << 1, tr[u].add);
addtag(u << 1 | 1, tr[u].add);
tr[u].add = 0;
}
100分程式碼:
#include <bits/stdc++.h>
using namespace std;
const int N = 100005;
struct Node
{
int l, r;
double sum1; //區間[l,r]的和
double sum2; //區間[l,r]的平方和
double add; //懶標記,表示將所有子節點對應區間每一個數增加add
} tr[N * 4];
double a[N];
int n, m;
void pushup(Node &root, Node &left, Node &right)
{
root.sum1 = left.sum1 + right.sum1;
root.sum2 = left.sum2 + right.sum2;
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
tr[u] = {l, r};
if(l == r) tr[u].sum1 = a[l], tr[u].sum2 = a[l] * a[l];
else
{
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void addtag(int u, double k)
{
tr[u].sum2 += 2 * k * tr[u].sum1 + k * k * (tr[u].r - tr[u].l + 1);
tr[u].sum1 += k * (tr[u].r - tr[u].l + 1);
tr[u].add += k;
}
void pushdown(int u)
{
addtag(u << 1, tr[u].add);
addtag(u << 1 | 1, tr[u].add);
tr[u].add = 0;
}
Node query(int u, int l, int r)
{
if(tr[u].l >= l && tr[u].r <= r) return tr[u];
else if(tr[u].l > r || tr[u].r < l) return Node{};
else
{
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
Node res = {};
Node left = query(u << 1, l, r);
Node right = query(u << 1 | 1, l, r);
pushup(res, left, right);
return res;
}
}
void update(int u, int l, int r, double k)
{
if(tr[u].l >= l && tr[u].r <= r) addtag(u, k);
else if(tr[u].l > r || tr[u].r < l) return;
else
{
pushdown(u);
update(u << 1, l, r, k);
update(u << 1 | 1, l, r, k);
pushup(u);
}
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
int op, x, y;
double k; //非常關鍵,注意如果定義int將得0分
while(m--)
{
cin >> op >> x >> y;
if(op == 1)
{
cin >> k;
update(1, x, y, k);
}
else if(op == 2) cout << fixed << setprecision(4) << query(1, x, y).sum1 / (y - x + 1) << endl;
else
{
Node res = query(1, x, y);
n = y - x + 1;
cout << fixed << setprecision(4) << res.sum2 / n - res.sum1 * res.sum1 / n / n << endl;
}
}
return 0;
}