ST表

blind5883發表於2024-05-25

有點時間補一下這玩意吧。

首先先說明 RMQ 是一類問題, 指 區間最大最小值, 而ST表是解解決 RMQ 問題的一把手術刀(手術刀, 鋒利但不通用)。

作用

\(O(\log n)\) 的預處理
\(O(1)\)區間最大值查詢
不可以更改區間數值

原理

原理是倍增

我們將設 \(f[i][j]\) 是從 \(i\) 處向外包含 \(2^j\) 個數中的最大值。即 \([i, i + 2^j - 1]\) 中的最大值,原陣列為 \(a[i]\),其中 \(f[i][0] = a[i]\)

因為任何區間長度 \(len\),無論怎麼分,其最多隻需要兩個 \(f[i][j]\) 就可以完全覆蓋它。

證明:區間長度 \(len\),可以分出一個比 \(len\) 小的最大的 \(2^n\),那麼 \(len - 2^n < 2^n\), 如果不符合這個的話, \(n\) 就可以繼續往上增加直到出現上述情況。因此一定有一個 \(n\) 使得 \(len - 2^n < 2^n\) 成立,而 \(f[i][j]\) 的覆蓋的長度為 \(2^j\),只要讓 \(j\) 等於上面的 \(n\),就可以最多用兩個 \(f\) 可以覆蓋其範圍,且不會超出邊界。也可以知道 \(n\) 就等於 \(\lfloor \log_2 len \rfloor\)

根據上面的證明我們也就知道如何去查詢了,設 \(f_1\) 一定包含這個區間的開頭,\(f_2\) 一定包含這個區間的末尾。大致如下圖。

設這個區間為 \([i, j]\),那麼區間長度 \(len\) 就是 \(j - i + 1\),證明中說到的 \(n\) 就等於 \(\log_2{len}\)
由圖可知 \(f_1\) 包含的區間就是 \([i, i + 2^n - 1]\)\(f_2\) 就是 \([j - 2^n + 1,j]\)。對應 \(f\) 陣列就是 \(f[i]i + 2^n - 1]\)\(f[j - 2^n + 1][j]\)

實現

問題來了怎麼實現它。

預處理

也很簡單,根據之前倍增 \(LCA\) 的思想(沒學過也可以),我們把一步拆成兩步走,即先跳 \(2^{j - 1}\) 步再跳 \(2^{j-1}\) 步,可得出遞推式 \(f[i][j] = f[f[i][j - 1]][j - 1]\),有了遞推式, 那麼求出它就很簡單了。

而求出 ST 表, 也就是預處理就是下面程式碼

for (int j = 0; j < M; j ++ ) // M是logn上取整, 即包含整個n
	for (int i = 1; i + (1 << j) - 1 <= n; i ++ )
		if (j == 0) f[i][j] = w[i]; // 如果只跳一步, 那麼最大值就是這個值它本身
		else f[i][j] = max(f[i][j], f[f[i][j - 1]][j - 1]); // 注意是求最大值

這個時間複雜度很好判斷, 最壞 \(O(n\log n)\),但是是很小的 \(\log n\),能從迴圈中看出來(手術刀)因為有預處理,所以是不能更改原陣列的,否則必須再次預處理(但這樣就不如用線段樹了)。

\(lg\) 陣列

上面提到了 \(lg\) 陣列,\(lg[i]\)\(\log_2i\) 下取整。

為什麼使用這個陣列而不是, 直接用函式 \(\log2()\) 呢?

這裡是為了保證查詢的 \(O(1)\),如果呼叫 \(\operatorname {log2()}\) 函式的話,時間複雜度會增加, 而透過預處理
\(lg\) 陣列的方式,就可以保證查詢 \(O(1)\)

\(lg\) 陣列一般預處理一遍即可,是 \(O(n\log n)\) 的時間複雜度,不會影響整體的預處理時間複雜度,可以直接加在上面的預處理裡面。

查詢

怎麼查詢呢?
你要知道, 最大值的區間是可以重疊的, 如 \([1, 5]\) 的最大值, 等於 \([1, 3]\) 的最大值和 \([2, 5]\) 的最大值的最大值, 雖然區間重疊了,但不影響答案的正確性,即最大值的區間是可以重疊的。

我們已經得到了 \(f_1,f_2\) (在上面原理中),根據上面的性質,那麼就很簡單了。
我們設 \(lg[i]\)\(\lfloor \log_2i \rfloor\) ,那麼從 \(i\)\(j\) 之間的長度是 \(len = j - i + 1\),最大值就是 \(\max(f[i][lg[len]], f[j - 2^{lg[i]} + 1][lg[len]])\)

其中 \(j - 2^{lg[i]} + 1\) ,這是 \(f_2\) 包含區間的開頭,比如 \([2, 5]\) 裡面有 \(4\) 個數,你從\(5\)\(4\)\(1\),但是你的區間是從 \(2\) 開始的,所以要加上 \(1\)。由區間 \([i, j]\) 長度計算公式 \(j - i + 1 = len\) 也可以得到 \(i = j - len + 1\) 這個式子。

程式碼

上面的要快一點點,下面的更好寫保證對,注意 \(\log2()\) 函式

int last = 0;
for (int i = 1; i <= n; i ++ )
{
	while (1 << last <= i) last ++ ; // 始終保證 2^last > i, 以便求出i的最小log2
	lg[i] = last - 1;
}

或者

for (int i = 1; i <= n; i ++ )
{
	lg[i] = log2(i);
}

例題

ST表(跳錶)

/* 
    中心思想: 倍增
    設f[i][j]是從i處向外2^j格里面的最大值;
    
    預處理是O(nlogn)
    查詢是O(1)的
    
    無法修改
    只能查詢
    像樹狀陣列一樣的"手術刀"
    
    因為查詢耗時O(1), 所以在"特殊情況"下沒法被O(log)的線段樹替代
*/

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

const int N = 200010, M = log2(N) + 1;

int n, m;
int w[N];
int f[N][M];
int lg[N];

void init()
{
    for (int j = 0; j < M; j ++ )
        for (int i = 1; i + (1 << j) - 1 <= n; i ++ )
            if (j == 0) f[i][j] = w[i];
            else f[i][j] = max(f[i][j - 1], f[i + (1 << j - 1)][j - 1]);

    int last = 0;
    for (int i = 1; i < N; i ++ ) // log陣列, 這裡的預處理是自己寫的, 利用一個last能幹好多事
    {
        while (1 << last <= i) last ++ ;
        lg[i] = last - 1;
    }
}

int query(int l, int r)
{
    int len = r - l + 1;
    return max(f[l][lg[len]], f[r - (1 << lg[len]) + 1][lg[len]]);
}

int main()
{
    cin >> n;
    for (int i = 1; i <= n; i ++ ) cin >> w[i];
    init();
    cin >> m;
    while (m -- )
    {
        int a, b;
        cin >> a >> b;
        cout << query(a, b) << endl;
    }
    
    return 0;
}

線段樹

/*
    線段樹的話, 比較簡單就不打註釋了
*/
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

using namespace std;

const int N = 200010;

int w[N];
int n, m;

struct Node
{
    int l, r;
    int maxv;
}tr[N * 4];

void pushup(int u)
{
    tr[u].maxv = max(tr[u << 1].maxv, tr[u << 1 | 1].maxv);
}

void build(int u, int l, int r)
{
    if (l == r) tr[u] = {l, l, w[l]};
    else
    {
        int mid = l + r >> 1;
        tr[u] = {l, r, -0x3f3f3f3f};
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

int query(int u, int l, int r)
{
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].maxv;
    else
    {
        int mid = tr[u].l + tr[u].r >> 1;
        int maxv = -0x3f3f3f3f;
        if (l <= mid) maxv = query(u << 1, l, r);
        if (r > mid) maxv = max(maxv, query(u << 1 | 1, l, r));
        
        return maxv;
    }
}

int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
    build(1, 1, n);
    scanf("%d", &m);
    
    while (m -- )
    {
        int l, r;
        scanf("%d%d", &l, &r);
        printf("%d\n", query(1, l, r));
    }
    return 0;
}

相關文章