樹狀陣列詳解

RioTian 發表於 2020-08-01

簡介

樹狀陣列和下面的線段樹可是親兄弟了,但他倆畢竟還有一些區別:
樹狀陣列能有的操作,線段樹一定有;
線段樹有的操作,樹狀陣列不一定有。

這麼看來選擇 線段樹 不就 「得天下了」

事實上,樹狀陣列的程式碼要比線段樹短得多,思維也更清晰,在解決一些單點修改的問題時,樹狀陣列是不二之選。


原理

如果要具體瞭解樹狀陣列的工作原理,請看下面這張圖:

樹狀陣列詳解

這個結構的思想和線段樹有些類似:用一個大節點表示一些小節點的資訊,進行查詢的時候只需要查詢一些大節點而不是更多的小節點。

最下面的八個方塊 (標有數字的方塊) 就代表存入 \(a\) 中的八個數,現在都是十進位制。

他們上面的參差不齊的剩下的方塊就代表 \(a\) 的上級—— \(c\) 陣列。

很顯然看出:
\(c[2]\) 管理的是 \(a[1]\) & \(a[2]\)
\(c[4]\) 管理的是 \(a[1]\) & \(a[2]\) & \(a[3]\) & \(a[4]\)
\(c[6]\) 管理的是 \(a[5]\) & \(a[6]\)\(c[8]\) 則管理全部 \(8\) 個數。

所以,如果你要算區間和的話,比如說要算 \(a[51]\) ~ \(a[91]\) 的區間和,暴力算當然可以,那上百萬的數,那就 RE 嘍。

那麼這種類似於跳一跳的連續跳到中心點而分值不斷變大的原理是一樣的(倍增)。

你從 \(91\) 開始往前跳,發現 \(c[n]\)\(n\) 我也不確定是多少,算起來太麻煩,就意思一下)只管 \(a[91]\) 這個點,那麼你就會找 \(a[90]\) ,發現 \(c[n - 1]\) 管的是 \(a[90]\) & \(a[89]\) ;那麼你就會直接跳到 \(a[88]\)\(c[n - 2]\) 就會管 \(a[81]\) ~ \(a[88]\) 這些數,下次查詢從 \(a[80]\) 往前找,以此類推。


用法及操作

那麼問題來了,你是怎麼知道 \(c\) 管的 \(a\) 的個數分別是多少呢?你那個 \(1\) 個, \(2\) 個, \(8\) 個……是怎麼來的呢?
這時,我們引入一個函式—— lowbit

int lowbit(int x) {
    //算出x二進位制的從右往左出現第一個1以及這個1之後的那些0組成數的二進位制對應的十進位制的數
    return x & -x;
}

lowbit 的意思註釋說明了,我們們就用這個說法來證明一下 \(a[88]\)
\(88_{(10)}=1011000_{(2)}\)
發現第一個 \(1\) 以及他後面的 \(0\) 組成的二進位制是 \(1000\)
\(1000_{(2)} = 8_{(10)}\)
\(1000\) 對應的十進位制是 \(8\) ,所以 \(c\) 一共管理 \(8\)\(a\)

這就是 lowbit 的用處,僅此而已(但也相當有用)。

你可能又問了:x & -x 是什麼意思啊?

在一般情況下,對於 int 型的正數,最高位是 0,接下來是其二進位制表示;而對於負數 (-x),表示方法是把 x 按位取反之後再加上 1 (補碼知識)。

例如 :
\(x =88_{(10)}=01011000_{(2)}\)
\(-x = -88_{(10)} = (10100111_{(2)} + 1_{(2)}) =10101000_{(2)}\)
\(x\ \& \ (-x) = 1000_{(2)} = 8_{(10)}\)

那麼對於 單點修改 就更輕鬆了:

void add(int x, int k) {
    while (x <= n) {  //不能越界
        c[x] = c[x] + k;
        x = x + lowbit(x);
    }
}

每次只要在他的上級那裡更新就行,自己就可以不用管了。

int getsum(int x) {  // a[1]……a[x]的和
    int ans = 0;
    while (x >= 1) {
        ans = ans + c[x];
        x = x - lowbit(x);
    }
    return ans;
}

區間加 & 區間求和

若維護序列 \(a\) 的差分陣列 \(b\) ,此時我們對 \(a\) 的一個字首 \(r\) 求和,即 \(\sum_{i=1}^{r} a_i\) ,由差分陣列定義得 \(a_i=\sum_{j=1}^i b_j\)

進行推導

\[\sum_{i=1}^{r} a_i\\=\sum_{i=1}^r\sum_{j=1}^i b_j\\=\sum_{i=1}^r b_i\times(r-i+1) \\=\sum_{i=1}^r b_i\times (r+1)-\sum_{i=1}^r b_i\times i \]

區間和可以用兩個字首和相減得到,因此只需要用兩個樹狀陣列分別維護 \(\sum b_i\)\(\sum i \times b_i\) ,就能實現區間求和。

程式碼如下

int t1[MAXN], t2[MAXN], n;

inline int lowbit(int x) { return x & (-x); }

void add(int k, int v) {
    int v1 = k * v;
    while (k <= n) {
        t1[k] += v, t2[k] += v1;
        k += lowbit(k);
    }
}

int getsum(int* t, int k) {
    int ret = 0;
    while (k) {
        ret += t[k];
        k -= lowbit(k);
    }
    return ret;
}

void add1(int l, int r, int v) {
    add(l, v), add(r + 1, -v);  //將區間加差分為兩個字首加
}

long long getsum1(int l, int r) {//1ll :代表長整型的 1
    return (r + 1ll) * getsum(t1, r) - 1ll * l * getsum(t1, l - 1) -
        (getsum(t2, r) - getsum(t2, l - 1));
}

Tricks

\(O(n)\) 建樹:

每一個節點的值是由所有與自己直接相連的兒子的值求和得到的。因此可以倒著考慮貢獻,即每次確定完兒子的值後,用自己的值更新自己的直接父親。

// O(n)建樹
void init() {
    for (int i = 1; i <= n; ++i) {
        t[i] += a[i];
        int j = i + lowbit(i);
        if (j <= n) t[j] += t[i];
    }
}

\(O(\log n)\) 查詢第 \(k\) 小/大元素。在此處只討論第 \(k\) 小,第 \(k\) 大問題可以通過簡單計算轉化為第 \(k\) 小問題。

參考 "可持久化線段樹" 章節中,關於求區間第 \(k\) 小的思想。將所有數字看成一個可重集合,即定義陣列 \(a\) 表示值為 \(i\) 的元素在整個序列重出現了 \(a_i\) 次。找第 \(k\) 大就是找到最小的 \(x\) 恰好滿足 \(\sum_{i=1}^{x}a_i \geq k\)

因此可以想到演算法:如果已經找到 \(x\) 滿足 \(\sum_{i=1}^{x}a_i \le k\) ,考慮能不能讓 \(x\) 繼續增加,使其仍然滿足這個條件。找到最大的 \(x\) 後, \(x+1\) 就是所要的值。
在樹狀陣列中,節點是根據 2 的冪劃分的,每次可以擴大 2 的冪的長度。令 \(sum\) 表示當前的 \(x\) 所代表的字首和,有如下演算法找到最大的 \(x\)

  1. 求出 \(depth=\left \lfloor log_2n \right \rfloor\)
  2. 計算 \(t=\sum_{i=x+1}^{x+2^{depth}}a_i\)
  3. 如果 \(sum+t \le k\) ,則此時擴充套件成功,將 \(2^{depth}\) 累加到 \(x\) 上;否則擴充套件失敗,對 \(x\) 不進行操作
  4. \(depth\) 減 1,回到步驟 2,直至 \(depth\) 為 0
//權值樹狀陣列查詢第k小
int kth(int k) {
  int cnt = 0, ret = 0;
  for (int i = log2(n); ~i; --i) {      // i與上文depth含義相同
    ret += 1 << i;                      //嘗試擴充套件
    if (ret >= n || cnt + t[ret] >= k)  //如果擴充套件失敗
      ret -= 1 << i;
    else
      cnt += t[ret];  //擴充套件成功後 要更新之前求和的值
  }
  return ret + 1;
}

時間戳優化:

對付多組資料很常見的技巧。如果每次輸入新資料時,都暴力清空樹狀陣列,就可能會造成超時。因此使用 \(tag\) 標記,儲存當前節點上次使用時間(即最近一次是被第幾組資料使用)。每次操作時判斷這個位置 \(tag\) 中的時間和當前時間是否相同,就可以判斷這個位置應該是 0 還是陣列內的值。

//權值樹狀陣列查詢第k小
int kth(int k) {
    int cnt = 0, ret = 0;
    for (int i = log2(n); ~i; --i) {      // i與上文depth含義相同
        ret += 1 << i;                      //嘗試擴充套件
        if (ret >= n || cnt + t[ret] >= k)  //如果擴充套件失敗
            ret -= 1 << i;
        else
            cnt += t[ret];  //擴充套件成功後 要更新之前求和的值
    }
    return ret + 1;
}

例題