樹狀陣列(待補)(生硬 公式 用法 證明)

blind5883發表於2024-11-25

理解之後挺簡單的。開始比較抽象,理解後比較簡單,不如直接看程式碼。

樹狀陣列是一個兼顧修改和查詢的資料結構。一般可以支援,\(O(\log n)\)區間查詢單點修改。如果原陣列為差分陣列,可以實現區間修改單點查詢。

因為常數小,在某些情況下比較好用,但一般它能做到的,線段樹都能做到。

中心思想

關於樹狀陣列的一切因為涉及 lowbit 比較抽象。

核心就是 \(tr\) 陣列,\(tr[i]\) 陣列代表原陣列 \(a\)\(i\) 為結尾長度為 \(\operatorname {lowbit}(i)\) 的下標區間和,也就是
\(a[i - \operatorname {lowbit}(i) + 1, i]\)

我們實質上是把這個陣列拆分成了一個樹,但計算原理執行上和樹沒什麼關係。只是呈樹形關係。關於原理見這篇文章 讓你頓悟樹狀陣列原理與由來 - 知乎 (zhihu.com)

查詢

我們知道一個數有二進位制的形式,如 \(11 = 1011\)。如果我們求原陣列下標 \([1, 11]\) 的和,可以把這個區間拆分,把每個拆分割槽間提前算出來,相當於預處理的方式,減少時間複雜度。

樹狀陣列就是按照 tr 陣列的形式來拆分了區間,像 \([1, 11]\) 就可以用(二進位制下)\(tr[1011],tr[1010],tr[1000]\) 這三個給湊出來,比起原來相加了 11 次,這隻加了 3 次,大大最佳化了時間。

對於一個區間可以不斷先加上 \(tr[\operatorname {lowbit}(i)]\),再把 i 減去 lowbit(i),這樣湊出。如 11 = 1011,有以下關係:

\[\begin{aligned} 1011 &=[1, 1011] \\ tr[1011] &= [1011,1011] \\ tr[1010] &= [1001, 1010] \\ tr[1000] &= [1, 1000] \\ \end{aligned}\]

\(tr[1011],tr[1010],tr[1000]\) 恰好為 \(tr[11],tr[11- \operatorname {lowbit}(11)] =tr[10],tr[10 -\operatorname {lowbit}(10)] = tr[8]\) 恰好符合上面的執行過程。

如果想求任意區間,方法和字首和求區間方式一樣。

int sum(int x) // 求出為第1 - x的和 sum/query
{
    int res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

關係

\(tr\) 陣列之間有關係,即一個數不斷相加當前的 lowbit 就可以找到所有包含它的 tr 陣列,即它可以影響的陣列,即父節點;同理透過不斷減當前的 lowbit 就可以所有找到影響它的 tr 陣列,即子節點。這個關係也就是稱為樹狀陣列的原因。

證明待補。自己想想也簡單。

修改

如果第 i 個數加上了 k,那麼所有包含 i 的 tr 都應該加上 k,可寫出。

void add(int x, int k) // 讓第i個數加上k
{
    for (int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}

求出 tr 陣列

我們求出的方式以 \(tr[10100]\) 舉例

tr[10100] = [10001,10100]

10011 = 10100 - 1
10010 = 10011 - lowbit(10011) = 10011 - 1
10010 - lowbit(10010) = 10010 - 10 = 10000 不屬於 [10001,10100]

tr[10011] = [10011,10011]
tr[10010] = [10001,10010]

tr[10100] = 
a[10100] + tr[10011] + tr[10010]

可以發現就是在不超過 \([10100 - lowbit(10100) + 1, 10100]\) 範圍內,原陣列的本身加上它本身 - 1 不斷減 lowbit 的 tr 之和。
程式碼如此

for (int i = 1; i <= n; i ++ ) tr[i] = a[i]; // 先賦值a[i]

for (int x = 1; x <= n; x ++ )
    for (int i = x - 1; i >= x - lowbit(x) + 1; i -= lowbit(i)) tr[x] += tr[i];

程式碼

初始化

// 初始化

int a[], tr[], sum[]; // a[]原陣列 tr[]樹狀陣列 sum[]字首和

int lowbit(int x)
{
    return x & -x;
}

void add(int x, int k)
{
    for (int i = x; i <= n; i + lowbit(i)) tr[i] += k;
}

// (1) O(nlogn) 常用 使用修改的方式

for (int i = 1; i <= n; i ++ ) add(i, a[i]); 

// (2) O(n)  根據邊

for (int i = 1; i <= n; i ++ ) tr[i] = a[i]; // 先賦值a[i]

for (int x = 1; x <= n; x ++ )
    for (int i = x - 1; i >= x - lowbit(x) + 1; i -= lowbit(i)) tr[x] += tr[i];

// (3) O(n) 根據原理

for (int i = 1; i <= n; i ++ ) sum[i] = sum[i - 1] + a[i]; 

for (int i = 1; i <= n; i ++ ) tr[i] = sum[i] - sum[i - lowbit(i)];

修改/查詢

void add(int x, int k) // 讓第x個數加上k
{
    for (int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}

int sum(int x) // 求出為第1 - x的和
{
    int res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

應用

特殊的應用,比如支援區間修改 + 區間查詢。這個的常數是對應線段樹的四分之一。

![[Pasted image 20241125081846.png|303]]

差分 + 樹狀陣列可以把區間修改最佳化到O(logn)
區間和 \(= a[i] + a[i + 1] + … + a[x] = b[1 … i] + b[1 … i + 1] + … b[1 … x]\) 最後可以把區間求和最佳化到 \(O(logn)\)

如這題 P11217 【MX-S4-T1】「yyOI R2」youyou 的垃圾桶 - 洛谷 | 電腦科學教育新生態 (luogu.com.cn) 如果使用線段樹的話 \(O(q {\log ^2 n})\) 一定過不了,而用樹狀陣列就可以。線段樹需要 \(O(q\log n)\) 並且常數要小。

程式碼

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long LL ;

const int N = 100010;

int n, m;
int a[N];
LL tr1[N]; // b[i] 的字首和
LL tr2[N]; // b[i] * i 的字首和

int lowbit(int x)
{
    return x & -x;
}

void add(LL tr[], int x, LL k)
{
    for (int i = x; i <= n; i += lowbit(i)) tr[i] += k;
}

LL sum(LL tr[], int x)
{
    LL res = 0;
    for (int i = x; i; i -= lowbit(i)) res += tr[i];
    return res;
}

LL p_sum(int x)
{
    return sum(tr1, x) * (x + 1) - sum(tr2, x); 
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &a[i]);

    for (int i = 1; i <= n; i ++ )
    {
        int b = a[i] - a[i - 1];
        add(tr1, i, b);
        add(tr2, i, (LL)b * i);
    }

    while (m -- )
    {
        int l, r, d;
        char op[2];
        scanf("%s%d%d", op, &l, &r);
        if (op[0] == 'C')
        {
            scanf("%d", &d);
            add(tr1, l, d), add(tr2, l, l * d);
            add(tr1, r + 1, -d), add(tr2, r + 1, (r + 1) * -d); 
        }
        else printf("%lld\n", p_sum(r) - p_sum(l - 1));   
    }


    return 0;
}

相關文章