LOJ #2005. 「SDOI2017」相關分析 線段樹維護迴歸直線方程

liuchanglc發表於2020-11-18

題目描述

\(Frank\) 對天文學非常感興趣,他經常用望遠鏡看星星,同時記錄下它們的資訊,比如亮度、顏色等等,進而估算出星星的距離,半徑等等。

\(Frank\) 不僅喜歡觀測,還喜歡分析觀測到的資料。他經常分析兩個引數之間(比如亮度和半徑)是否存在某種關係。

現在 \(Frank\) 要分析引數 \(X\)\(Y\) 之間的關係。他有 \(n\) 組觀測資料,第 \(i\) 組觀測資料記錄了 \(x_i\)\(y_i\)​。他需要一下幾種操作

\(1\ L,R:\)

用直線擬合第 \(L\) 組到第 \(R\) 組觀測資料。用 \(\overline{x}\) 表示這些觀測資料中 \(x\) 的平均數,用 \(\overline{y}\) ​表示這些觀測資料中 \(y\) 的平均數,即

\(\overline{x}={1 \over R-L+1} \sum _{i=L} ^R x_i\)

\(\overline{y}={1 \over R-L+1} \sum _{i=L} ^R y_i\)

如果直線方程是 \(y=ax+b\),那麼 \(a,b\) 應當這樣計算:

\(a={\sum_{i=L} ^R (x_i-\overline{x})(y_i-\overline{y}) \over \sum _{i=L} ^R (x_i -\overline{x})^2}\)

你需要幫助 \(Frank\) 計算 \(a\)

\(2\ L,R,S,T:\)

\(Frank\) 發現測量資料第 \(L\) 組到第 \(R\) 組資料有誤差,對每個 \(i\) 滿足 \(L \leq i \leq R\)\(x_i\) ​需要加上 \(S\)\(y_i\) ​需要加上\(T\)

\(3\ L,R,S,T:\)

\(Frank\)發現第 \(L\) 組到第 \(R\) 組資料需要修改,對於每個 \(i\) 滿足 \(L \leq i \leq R\)\(x_i\)​需要修改為 \((S+i)\)\(y_i\) ​需要修改為 \((T+i)\)
輸入格式

第一行兩個數 \(n,m\),表示觀測資料組數和操作次數。

接下來一行 \(n\) 個數,第 \(i\) 個數是 \(x_i\)​。

接下來一行 \(n\) 個數,第 \(i\) 個數是 \(y_i\)​。

接下來 \(m\) 行,表示操作,格式見題目描述。

輸出格式

對於每個 \(1\) 操作,輸出一行,表示直線斜率 \(a\)。選手輸出與標準輸出的絕對誤差或相對誤差不超過 \(10^{-5}\) 即為正確。

輸入輸出樣例

輸入 #1

3 5
1 2 3
1 2 3
1 1 3
2 2 3 -3 2
1 1 2
3 1 2 2 1
1 1 3

輸出 #1

1.0000000000
-1.5000000000
-0.6153846154

說明/提示

對於 \(20\%\) 的資料 \(1 \leq n,m \leq 1000\)

另有 \(20\%\) 的資料,沒有 \(3\) 操作,且 \(2\) 操作中 \(S=0\)

另有 \(30\%\) 的資料,沒有 \(3\) 操作。

對於 \(100\%\) 的資料,\(1 \leq n,m \leq 10^5,0 \leq |S|,|T| \leq 10^5,0 \leq |x_i|,|y_i| \leq 10^5\)

保證 \(1\) 操作不會出現分母為 \(0\) 的情況。

時間限制:\(1s\)

空間限制:\(128MB\)

分析

把式子化簡,就會得到

\(\begin{aligned} & \sum (x_i - \bar x)(y_i - \bar y) \\ = & \sum (x_i y_i - x_i \bar y - y_i \bar x_i + \bar x \bar y) \\ = & \sum x_i y_i - \bar y\sum x_i - \bar x\sum y_i + n\bar x \bar y \\ = & \sum x_i y_i - n\bar x \bar y \\ \\ & \sum (x_i - \bar x)^2 \\ = & \sum (x_i^2 + {\bar x}^2 - 2x_i\bar x) \\ = & \sum x_i^2 +n{\bar x}^2 - 2\bar x\sum x_i \\ = & \sum x_i^2 -n {\bar x}^2\\ \end{aligned}\)

那麼我們要維護的東西就是 \(x_i\)\(y_i\)\(x_iy_i\)

對於操作 \(2\)

\(\begin{aligned} &\sum x_i\to\sum(x_i+S)=\sum x_i+nS \\ &\sum y_i\to\sum(y_i+T)=\sum y_i+nT \\ &\sum x_i^2\to\sum(x_i+S)^2=\sum x_i^2+nS^2+2S\sum x_i\\ &\sum x_i y_i \to\sum(x_i+S)(y_i+T)=\sum x_i y_i+T\sum x_i+S\sum y_i+nST \\ \end{aligned}\)

對於操作 \(3\)

\(\begin{aligned} &\sum x_i\to\sum(i+S)=s_1+nS \\ &\sum y_i\to\sum(i+T)=s_1+nT \\ &\sum x_i^2\to\sum(i+S)^2=s_2+nS^2+2Ss_1\\ &\sum x_i y_i \to\sum(i+S)(i+T)=s_2+(T+S)s_1+nST\\ \end{aligned}\)

其中 \(s_1\) 是等差數列的求和公式 \(\frac{n(n+1)}{2}\)

\(s_2\)\(i^2\) 的字首和 \(\frac{n(n+1)(2n+1)}{6}\)

注意下放標記的時候只要有一個不為零就要下放

要先下放覆蓋的標記,再下放加的標記

程式碼

#include <cstdio>
#include <algorithm>
#include <cmath>
#define rg register
const int maxn = 1e5 + 5;
typedef double db;
int n, m;
db jlx[maxn], jly[maxn];
struct trr {
    int l, r, siz;
    db sumx, sumy, sumxx, sumxy, lazx, lazy, tagx, tagy;
    trr() {
        tagx = tagy = 1e18;
        sumx = sumy = sumxx = sumxy = lazx = lazy = 0;
        l = r = siz = 0;
    }
} tr[maxn << 2];
db getsum1(int l, int r) { return (db)(r - l + 1.0) * (l + r) / 2.0; }
db getsum2(int r) { return (db)r * (r + 1.0) * (2.0 * r + 1.0) / 6.0; }
void push_up(int da) {
    tr[da].sumx = tr[da << 1].sumx + tr[da << 1 | 1].sumx;
    tr[da].sumy = tr[da << 1].sumy + tr[da << 1 | 1].sumy;
    tr[da].sumxx = tr[da << 1].sumxx + tr[da << 1 | 1].sumxx;
    tr[da].sumxy = tr[da << 1].sumxy + tr[da << 1 | 1].sumxy;
}
void push_down(int da) {
    if (tr[da].tagx != 1e18 || tr[da].tagy != 1e18) {
        tr[da << 1].tagx = tr[da].tagx;
        tr[da << 1 | 1].tagx = tr[da].tagx;
        tr[da << 1].tagy = tr[da].tagy;
        tr[da << 1 | 1].tagy = tr[da].tagy;
        tr[da << 1].sumx = tr[da].tagx * tr[da << 1].siz + getsum1(tr[da << 1].l, tr[da << 1].r);
        tr[da << 1 | 1].sumx =
            tr[da].tagx * tr[da << 1 | 1].siz + getsum1(tr[da << 1 | 1].l, tr[da << 1 | 1].r);
        tr[da << 1].sumy = tr[da].tagy * tr[da << 1].siz + getsum1(tr[da << 1].l, tr[da << 1].r);
        tr[da << 1 | 1].sumy =
            tr[da].tagy * tr[da << 1 | 1].siz + getsum1(tr[da << 1 | 1].l, tr[da << 1 | 1].r);
        tr[da << 1].sumxx = tr[da << 1].siz * tr[da].tagx * tr[da].tagx +
                            2.0 * tr[da].tagx * getsum1(tr[da << 1].l, tr[da << 1].r) +
                            getsum2(tr[da << 1].r) - getsum2(tr[da << 1].l - 1);
        tr[da << 1 | 1].sumxx = tr[da << 1 | 1].siz * tr[da].tagx * tr[da].tagx +
                                2.0 * tr[da].tagx * getsum1(tr[da << 1 | 1].l, tr[da << 1 | 1].r) +
                                getsum2(tr[da << 1 | 1].r) - getsum2(tr[da << 1 | 1].l - 1);
        tr[da << 1].sumxy = tr[da << 1].siz * tr[da].tagx * tr[da].tagy +
                            (tr[da].tagx + tr[da].tagy) * getsum1(tr[da << 1].l, tr[da << 1].r) +
                            getsum2(tr[da << 1].r) - getsum2(tr[da << 1].l - 1);
        tr[da << 1 | 1].sumxy = tr[da << 1 | 1].siz * tr[da].tagx * tr[da].tagy +
                                (tr[da].tagx + tr[da].tagy) * getsum1(tr[da << 1 | 1].l, tr[da << 1 | 1].r) +
                                getsum2(tr[da << 1 | 1].r) - getsum2(tr[da << 1 | 1].l - 1);
        tr[da].tagx = tr[da].tagy = 1e18;
        tr[da << 1].lazx = tr[da << 1 | 1].lazx = tr[da << 1].lazy = tr[da << 1 | 1].lazy = 0;
    }
    if (tr[da].lazx != 0 || tr[da].lazy != 0) {
        tr[da << 1].lazx += tr[da].lazx;
        tr[da << 1 | 1].lazx += tr[da].lazx;
        tr[da << 1].lazy += tr[da].lazy;
        tr[da << 1 | 1].lazy += tr[da].lazy;
        tr[da << 1].sumxx +=
            2.0 * tr[da].lazx * tr[da << 1].sumx + tr[da << 1].siz * tr[da].lazx * tr[da].lazx;
        tr[da << 1 | 1].sumxx +=
            2.0 * tr[da].lazx * tr[da << 1 | 1].sumx + tr[da << 1 | 1].siz * tr[da].lazx * tr[da].lazx;
        tr[da << 1].sumxy += tr[da << 1].sumx * tr[da].lazy + tr[da << 1].sumy * tr[da].lazx +
                             tr[da << 1].siz * tr[da].lazx * tr[da].lazy;
        tr[da << 1 | 1].sumxy += tr[da << 1 | 1].sumx * tr[da].lazy + tr[da << 1 | 1].sumy * tr[da].lazx +
                                 tr[da << 1 | 1].siz * tr[da].lazx * tr[da].lazy;
        tr[da << 1].sumx += tr[da << 1].siz * tr[da].lazx;
        tr[da << 1 | 1].sumx += tr[da << 1 | 1].siz * tr[da].lazx;
        tr[da << 1].sumy += tr[da << 1].siz * tr[da].lazy;
        tr[da << 1 | 1].sumy += tr[da << 1 | 1].siz * tr[da].lazy;
        tr[da].lazx = tr[da].lazy = 0;
    }
}
void build(int da, int l, int r) {
    tr[da].l = l, tr[da].r = r, tr[da].siz = r - l + 1;
    if (tr[da].l == tr[da].r) {
        tr[da].sumx = jlx[l];
        tr[da].sumy = jly[l];
        tr[da].sumxx = jlx[l] * jlx[l];
        tr[da].sumxy = jlx[l] * jly[l];
        return;
    }
    rg int mids = (tr[da].l + tr[da].r) >> 1;
    build(da << 1, l, mids);
    build(da << 1 | 1, mids + 1, r);
    push_up(da);
}
void ad(int da, int l, int r, db valx, db valy) {
    if (tr[da].l >= l && tr[da].r <= r) {
        tr[da].lazx += valx;
        tr[da].lazy += valy;
        tr[da].sumxx += 2.0 * valx * tr[da].sumx + tr[da].siz * valx * valx;
        tr[da].sumxy += tr[da].sumx * valy + tr[da].sumy * valx + tr[da].siz * valx * valy;
        tr[da].sumx += tr[da].siz * valx;
        tr[da].sumy += tr[da].siz * valy;
        return;
    }
    push_down(da);
    rg int mids = (tr[da].l + tr[da].r) >> 1;
    if (l <= mids)
        ad(da << 1, l, r, valx, valy);
    if (r > mids)
        ad(da << 1 | 1, l, r, valx, valy);
    push_up(da);
}
void xg(int da, int l, int r, db valx, db valy) {
    if (tr[da].l >= l && tr[da].r <= r) {
        tr[da].lazx = 0, tr[da].lazy = 0;
        tr[da].tagx = valx;
        tr[da].tagy = valy;
        tr[da].sumx = valx * tr[da].siz + getsum1(tr[da].l, tr[da].r);
        tr[da].sumy = valy * tr[da].siz + getsum1(tr[da].l, tr[da].r);
        tr[da].sumxx = tr[da].siz * valx * valx + 2.0 * valx * getsum1(tr[da].l, tr[da].r) +
                       getsum2(tr[da].r) - getsum2(tr[da].l - 1);
        tr[da].sumxy = tr[da].siz * valx * valy + (valx + valy) * getsum1(tr[da].l, tr[da].r) +
                       getsum2(tr[da].r) - getsum2(tr[da].l - 1);
        return;
    }
    push_down(da);
    rg int mids = (tr[da].l + tr[da].r) >> 1;
    if (l <= mids)
        xg(da << 1, l, r, valx, valy);
    if (r > mids)
        xg(da << 1 | 1, l, r, valx, valy);
    push_up(da);
}
db cxx(int da, int l, int r) {
    if (tr[da].l >= l && tr[da].r <= r) {
        return tr[da].sumx;
    }
    push_down(da);
    rg int mids = (tr[da].l + tr[da].r) >> 1;
    rg db nans = 0;
    if (l <= mids)
        nans += cxx(da << 1, l, r);
    if (r > mids)
        nans += cxx(da << 1 | 1, l, r);
    return nans;
}
db cxy(int da, int l, int r) {
    if (tr[da].l >= l && tr[da].r <= r) {
        return tr[da].sumy;
    }
    push_down(da);
    rg int mids = (tr[da].l + tr[da].r) >> 1;
    rg db nans = 0;
    if (l <= mids)
        nans += cxy(da << 1, l, r);
    if (r > mids)
        nans += cxy(da << 1 | 1, l, r);
    return nans;
}
db cxxx(int da, int l, int r) {
    if (tr[da].l >= l && tr[da].r <= r) {
        return tr[da].sumxx;
    }
    push_down(da);
    rg int mids = (tr[da].l + tr[da].r) >> 1;
    rg db nans = 0;
    if (l <= mids)
        nans += cxxx(da << 1, l, r);
    if (r > mids)
        nans += cxxx(da << 1 | 1, l, r);
    return nans;
}
db cxxy(int da, int l, int r) {
    if (tr[da].l >= l && tr[da].r <= r) {
        return tr[da].sumxy;
    }
    push_down(da);
    rg int mids = (tr[da].l + tr[da].r) >> 1;
    rg db nans = 0;
    if (l <= mids)
        nans += cxxy(da << 1, l, r);
    if (r > mids)
        nans += cxxy(da << 1 | 1, l, r);
    return nans;
}
db getx(int l, int r) { return (db)cxx(1, l, r) / (r - l + 1); }
db gety(int l, int r) { return (db)cxy(1, l, r) / (r - l + 1); }
void solve(int l, int r) {
    db ans1 = cxxy(1, l, r) - (db)(r - l + 1) * getx(l, r) * gety(l, r);
    db ans2 = cxxx(1, l, r) - (db)(r - l + 1) * getx(l, r) * getx(l, r);
    printf("%.10f\n", ans1 / ans2);
}
int main() {
    scanf("%d%d", &n, &m);
    for (rg int i = 1; i <= n; i++) {
        scanf("%lf", &jlx[i]);
    }
    for (rg int i = 1; i <= n; i++) {
        scanf("%lf", &jly[i]);
    }
    build(1, 1, n);
    rg int aa, bb, cc;
    db dd, ee;
    for (rg int i = 1; i <= m; i++) {
        scanf("%d%d%d", &aa, &bb, &cc);
        if (aa == 1) {
            solve(bb, cc);
        } else if (aa == 2) {
            scanf("%lf%lf", &dd, &ee);
            ad(1, bb, cc, dd, ee);
        } else {
            scanf("%lf%lf", &dd, &ee);
            xg(1, bb, cc, dd, ee);
        }
    }
    return 0;
}

相關文章