洛谷題單指南-線段樹-P1471 方差

五月江城發表於2024-12-05

原題連結:https://www.luogu.com.cn/problem/P1471

題意解讀:給定序列a[n],支援三種操作:1.將區間每個數加上一個數 2.查詢區間的平均數 3、查詢區間的方差

解題思路:要支援區間修改和查詢,首選線段樹,下面看線段樹節點需要維護的資訊

平均數 = 區間和 / n,所以第一個要維護的資訊是區間和

再將計算方差的公式展開:

洛谷題單指南-線段樹-P1471 方差

因此,除了維護區間和,還需要維護區間平方和

由於要進行區間修改,需要用到懶標記,定義節點為:

struct  Node
{
    int l, r;
    double sum1; //區間[l,r]的和
    double sum2; //區間[l,r]的平方和
    double add; //懶標記,表示將所有子節點對應區間每一個數增加add
} tr[N * 4];

由區間和以及區間平方和的定義可知,pushup操作為:

void pushup(Node &root, Node &left, Node &right)
{
    root.sum1 = left.sum1 + right.sum1;
    root.sum2 = left.sum2 + right.sum2;
}

void pushup(int u)
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

當執行區間修改操作時,如給區間[l,r]每個數增加k

則有sum1 = sum1 + k * (r-l+1)

而每個數增加k,平方和公式可以展開為:

洛谷題單指南-線段樹-P1471 方差

則有sum2 = sum2 + 2 * k * sum1 + (r-l+1) * k * k

注意,sum2依賴的sum1是沒有加k之前的,所以要先計算sum2,再計算sum1

所以,新增懶標記以及pushdown操作為:

void addtag(int u, double k)
{
    tr[u].sum2 += 2 * k * tr[u].sum1 + k * k * (tr[u].r - tr[u].l + 1);
    tr[u].sum1 += k * (tr[u].r - tr[u].l + 1);
    tr[u].add += k;
}

void pushdown(int u)
{
    addtag(u << 1, tr[u].add);
    addtag(u << 1 | 1, tr[u].add);
    tr[u].add = 0;
}

100分程式碼:

#include <bits/stdc++.h>
using namespace std;

const int N = 100005;

struct  Node
{
    int l, r;
    double sum1; //區間[l,r]的和
    double sum2; //區間[l,r]的平方和
    double add; //懶標記,表示將所有子節點對應區間每一個數增加add
} tr[N * 4];
double a[N];
int n, m;

void pushup(Node &root, Node &left, Node &right)
{
    root.sum1 = left.sum1 + right.sum1;
    root.sum2 = left.sum2 + right.sum2;
}

void pushup(int u)
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r)
{
    tr[u] = {l, r};
    if(l == r) tr[u].sum1 = a[l], tr[u].sum2 = a[l] * a[l];
    else 
    {
        int mid = l + r >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void addtag(int u, double k)
{
    tr[u].sum2 += 2 * k * tr[u].sum1 + k * k * (tr[u].r - tr[u].l + 1);
    tr[u].sum1 += k * (tr[u].r - tr[u].l + 1);
    tr[u].add += k;
}

void pushdown(int u)
{
    addtag(u << 1, tr[u].add);
    addtag(u << 1 | 1, tr[u].add);
    tr[u].add = 0;
}

Node query(int u, int l, int r)
{
    if(tr[u].l >= l && tr[u].r <= r) return tr[u];
    else if(tr[u].l > r || tr[u].r < l) return Node{};
    else
    {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        Node res = {};
        Node left = query(u << 1, l, r);
        Node right = query(u << 1 | 1, l, r);
        pushup(res, left, right);
        return res;
    }
}

void update(int u, int l, int r, double k)
{
    if(tr[u].l >= l && tr[u].r <= r) addtag(u, k);
    else if(tr[u].l > r || tr[u].r < l) return;
    else 
    {
        pushdown(u);
        update(u << 1, l, r, k);
        update(u << 1 | 1, l, r, k);
        pushup(u);
    }
}

int main()
{
    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> a[i];
    build(1, 1, n);
    int op, x, y;
    double k; //非常關鍵,注意如果定義int將得0分
    while(m--)
    {
        cin >> op >> x >> y;
        if(op == 1)
        {
            cin >> k;
            update(1, x, y, k);
        }
        else if(op == 2) cout << fixed << setprecision(4) << query(1, x, y).sum1 / (y - x + 1) << endl;
        else 
        {
            Node res = query(1, x, y);
            n = y - x + 1;
            cout << fixed << setprecision(4) << res.sum2 / n - res.sum1 * res.sum1 / n / n << endl;
        }
    }
    return 0;
}

相關文章