可持久化線段樹

blind5883發表於2024-11-09

少寫了一點,可持久化的好處就是可以用較低的代價去得到可以變換版本這一功能。

可持久化線段樹(主席樹) 帶註釋的程式碼

/*
    注意, 可持久化線段樹很難支援區間修改, 一般涉及區間修改的時候不用
    單點修改是可以的
    
    一樣, 直接選這題不大好, 看下面的通用模版, 具有通用性, 不要被這題禁錮了思想 
    
    思路還算簡單
    首先你需要一個線段樹的框架, 即root[0], 因為可持久化線段樹基本不改變其框架
    只改變其中的資訊, 比如最大值最小值, 這裡線段樹可以存下標範圍, 也可以存值域, 只要線段樹的框架不變就行了
    但是這時候線段樹裡面的l, r 就存的是它的左右兒子, 而非, 左右邊界, 相當於指標(一般用陣列進行模擬), 這個點的左右邊界可以透過遞迴傳下來
    
    可持久化的思路, 每次只修改新加入的點(對與上一個版本), 因此, 我們可以先複製一份
    一模一樣的, 即tr[q] = tr[p] (p是上一版本, q是這一版本)
    這時候只修改要修改的(比如這個點左兒子的資訊, 右兒子不變)就行了, 這樣依舊可以搜到之前版本中沒修改的點, 也不會更改之前版本
    對於修改的點, 每次新建立一個新點(每個版本最多會建立logn個點, 加上之前的骨架, 一共就是(n * 4 + nlogn)個點)
    
    對於這題
    首先得保證線段樹的結構不變, 因為沒有修改, 所以結構肯定的不變,
    但是為了好算第k小數, 我們可持久化線段樹存的就不是陣列的下標範圍,
    而是一個值域, 離散化後的值域, 因為點的數量n是固定的, 那麼這個值域(離散化後的)就是固定的
    就不會改變這個線段樹的框架, 同時保序離散化後, 是從小到大排的, 這就方便我們進行求第k小數
    
    這題線段樹裡面多存一個cnt, 表示這個區間裡面的數量, 因為這裡保序離散化了, 
    如果一個點左兒子內的數量大於等於k, 那麼總體第k小數, 就在左兒子裡面, 否則就在右兒子裡面
    
    這兩句引用我程式碼裡的註釋
    如果左兒子cnts >= k, 那麼總體第k小的, 在左兒子裡還是第k小的
    如果左兒子cnts不滿足k個, 說明第k小肯定不在左兒子裡面, 就去右兒子裡面找, 並且, 因為左兒子裡面有cnts個數, 那麼在右兒子裡面, 總體第k小數, 應該是第k - cnts小的數
    
    而對於限制[l, r], 對於r可以直接在第r個版本去搜數量
    對於[l, r], 可以使用字首和的思想, 在第l - 1個版本里面的這個區間的數量cnt1, 第r個版本的數量cnt2, cnt2 - cnt1就是[l, r]版本內這個區間新加的數量
    也就是第[l, r]個數直接的數量, 這樣區間限制就解決了(注意可持久化線段樹, 總體框架不變, 每個點在不同版本都有對應的)
    
    至此此題結束
    
    具體見程式碼
*/
#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 100010, M = 10010;

int n, m;
int id[N], w[N], idx, cnt; 
int root[N]; // 每個版本的入口

struct Node // 實際上線段樹每個節點存的是一個"值"域
{
    int l, r; // 這裡的l和r, 不是左右邊界, 是左右兒子的下標(idx)
    int cnt; // 每個值域裡面數的數量
}tr[N * 4 + 17 * N]; // logn大約是17,

int find(int x) // 保序離散化
{
    int l = 1, r = cnt;
    while (l < r)
    {
        int mid = l + r >> 1;
        if (x <= id[mid]) r = mid;
        else l = mid + 1;
    }
    
    return l;
}

int build(int l, int r) // 建立基本骨架
{
    int p = ++ idx;
    if (l == r) return p;
    
    int mid = l + r >> 1;
    tr[p].l = build(l, mid);
    tr[p].r = build(mid + 1, r);
    
    return p;
}

int insert(int p, int l, int r, int x)
{
    int q = ++ idx;
    tr[q] = tr[p]; // 複製這一點, 注意這裡tr[p]內已經有值了
    if (l == r) // 說明搜到這一點了, 把這一點的數量++
    {
        tr[q].cnt ++ ;
        return q;
    }
    
    int mid = l + r >> 1;
    if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x); // 更改左兒子
    else tr[q].r = insert(tr[p].r, mid + 1, r, x); // 更改右兒子
    
    tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt; // 計算cnt, 這裡可以寫一個pushup
    
    return q;
}


int query(int p, int q, int l, int r, int k)
{
    int cnts = tr[tr[p].l].cnt - tr[tr[q].l].cnt; // 這裡我命名的是cnts, 別寫錯了
    int mid = l + r >> 1;
    if (l == r) return l; // 說明找到這個點了, 返回的是l, 是第幾個點
    
    if (cnts >= k) return query(tr[p].l, tr[q].l, l, mid, k); // 如果cnts >= k, 那麼總體第k小的, 在左兒子裡還是第k小的
    else query(tr[p].r, tr[q].r, mid + 1, r, k - cnts); // 如果cnts不滿足k個, 說明第k小肯定不在左兒子裡面, 就去右兒子裡面找, 並且, 因為左兒子裡面有cnts個數, 那麼在右兒子裡面, 總體第k小數, 應該是第k - cnts小的數
}

int main()
{
    cin >> n >> m; 
    
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]), id[ ++ cnt] = w[i];
    
    sort(id + 1, id + 1 + cnt);
    cnt = unique(id + 1, id + 1 + cnt) - id - 1; // 判重
    
    root[0] = build(1, cnt); // 初始框架

    for (int i = 1; i <= n; i ++ ) root[i] = insert(root[i - 1], 1, cnt, find(w[i]));  // 每個版本
    
    while (m -- )
    {
        int l, r, k;
        scanf("%d%d%d", &l, &r, &k);
        printf("%d\n", id[query(root[r], root[l - 1], 1, cnt, k)]); // 注意每次輸出的是原數而不是離散化之後的
    }
    
    return 0;
}


通用模版

洛谷P3919

/*
    不容易, 算是打出來模版了
    還是模版好理解一些, 上來就第k小數, 其實並不能真正瞭解這個
    這個模版可以讓你更加細緻了理解可持久化線段樹
    洛谷P3919 【模板】可持久化線段樹 1(可持久化陣列)
    
    根據題意線段樹裡面隨便存個值就行, 最大值也行, 最小也行, 只存葉節點的也行
    剛開始的時候給我看傻了, 後邊一想, 好像存什麼都行, 這裡存最大值
    
    為什麼可持久化線段樹裡面必須存左右兒子下標?
    可持久化線段樹裡面的l和r都是代表左右兒子, 為什麼, 因為可持久化線段樹裡面有很多版本, 每個版本有新點, 
    新點會打亂下標順序, 就不能透過堆的方式來找到左右兒子, 所以要存左右兒子的下標來找到它們, 這是必須存的
    而左右邊界, 你可以在結構體另開新的變數存, 也可以透過遞迴傳下去, 這裡使用遞迴傳下去, 思想要開啟
    
    注意題意, 剩下看模版把, 其實和普通線段樹差不太多
*/
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <cstdio>

using namespace std;

const int N = 1000010; 

int n, m, cnt;
int root[N], w[N], idx;

struct Node
{
    int l, r; // 左右兒子
    int v; // 這個區間的最大數
}tr[N * 4 + N * (int)ceil(log(N) / log(2))];

void pushup(int p)
{
    tr[p].v = max(tr[tr[p].l].v, tr[tr[p].r].v);
}

int build(int l, int r) // 這裡題目中給出了初始化版本了, 所以初始版本要更新最大值
{
    int p = ++ idx;
    if (l == r) 
    {
        tr[p].v = w[l]; // 葉節點存當前值
        return p;
    }
    int mid = l + r >> 1;
    tr[p].l = build(l, mid);
    tr[p].r = build(mid + 1, r);
    pushup(p); // 就是不pushup也可以過掉, 注意上面說的, 不pushup還快, 被卡了可以試試a
    return p;
}

int insert(int p, int l, int r, int x, int k)
{
    int q = ++ idx;  // 開新點
    tr[q] = tr[p]; // 複製
    
    if (l == r) 
    {
        tr[q].v = k;
        return q;
    }
    
    int mid = l + r >> 1;
    if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x, k);
    else tr[q].r = insert(tr[p].r, mid + 1, r, x, k);
    pushup(q); // 就是不pushup也可以過掉
    return q;
}

int query(int p, int l, int r, int x)
{
    if (l == r) return tr[p].v;
    int mid = l + r >> 1;
    
    if (x <= mid) return query(tr[p].l, l, mid, x);
    else return query(tr[p].r, mid + 1, r, x);
}

int main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
    
    root[0] = build(1, n); // 初始版本
    
    while (m -- )
    {
        int v, op, d, x;
        scanf("%d%d", &v, &op);
        if (op == 1) 
        {
            scanf("%d%d", &x, &d);
            root[ ++ cnt] = insert(root[v], 1, n, x, d); // 新開版本
        }
        else
        {
            scanf("%d", &x);
            root[ ++ cnt] = root[v]; // 根據題意新開版本
            printf("%d\n", query(root[v], 1, n, x));
        }
    }
    return 0;
}

截圖

螢幕截圖 2023-10-30 215337.png
螢幕截圖 2023-10-31 075836.png
螢幕截圖 2023-10-31 075855.png

對於這題的另一種做法

自己想出來的,感覺要容易想到,時間上要比y的慢一倍。大體思想就是,我們從小到大依次加入一個數,每加入一個就記錄一個版本,線段樹裡記錄區間裡數的數量,在查詢時,只要二分出區間數的數量大於等於k的最小版本即可,這個版本對應插入的點就是要求的第 k 小點,時間複雜度是 \(O(n\log^2n)\) 的和 y 是一個量級的,可能是由於常數問題,所以執行上要慢。

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

const int N = 100010;

int n, m;
int idx, root[N], cnt;
int g[N];

struct node
{
    int v, id;
    bool operator<(const node &W)const
    {
        return v < W.v;
    }
}a[N];

struct Node
{
    int l, r;
    int v, sum = 0;
}tr[N * 4 + N * (int)ceil(log2(N))];

void pushup(int u)
{
    int &l = tr[u].l, &r = tr[u].r;
    tr[u].sum = tr[l].sum + tr[r].sum;
}

int build(int l, int r)
{
    int p = ++ idx;
    if (l == r)
    {
        tr[p].v = -0x3f3f3f3f;
        tr[p].sum = 0;
        return p;
    }
    int mid = l + r >> 1;
    tr[p].l = build(l, mid);
    tr[p].r = build(mid + 1, r);
    pushup(p);
    return p;
}

int insert(int p, int l, int r, int x, int k)
{
    int q = ++ idx;
    tr[q] = tr[p];
    if (l == r)
    {
        tr[q].v = k;
        if (k > -0x3f3f3f3f) tr[q].sum = 1;
        return q;
    }
    int mid = l + r >> 1;
    if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x, k);
    else tr[q].r = insert(tr[p].r, mid + 1, r, x, k);
    pushup(q);
    return q;
}

int query(int p, int l, int r, int x, int y)
{
    if (x <= l && r <= y) return tr[p].sum;
    
    int mid = l + r >> 1;
    int sum = 0;
    if (x <= mid) sum += query(tr[p].l, l, mid, x, y);
    if (y > mid) sum += query(tr[p].r, mid + 1, r, x, y);
    
    return sum;
}

bool check(int x, int l, int r, int k)
{
    return query(root[x], 1, n, l, r) >= k;
}

int main()
{
    cin >> n >> m;
    
    root[0] = build(1, n);
    for (int i = 1; i <= n; i ++ ) 
    {
        int x;
        scanf("%d", &x);
        a[i] = {x, i};
        g[i] = x;
    }
    
    sort(a + 1, a + n + 1);
    
    for (int i = 1; i <= n; i ++ ) 
    {
        root[i] = insert(root[i - 1], 1, n, a[i].id, a[i].v);
        // cout << i << endl;
    }
    
    while (m -- )
    {
        int ls, rs, k;
        scanf("%d%d%d", &ls, &rs, &k);
        
        int l = 0, r = n, mid;
        while (l < r)
        {
            mid = l + r >> 1;
            if (check(mid, ls, rs, k)) r = mid;
            else l = mid + 1;
        }
        
        printf("%d\n", a[l].v);
    }
    
    // cout << query(root[5], 1, n, 2, 5);
    
    
    return 0;
    
}

相關文章