FHQ-treap模板

lyrrr發表於2024-11-25

可以再加一個 struct 把整棵樹封裝起來。此寫法參考自 OI Wiki。

#include<bits/stdc++.h>
using namespace std;

#include<bits/stdc++.h>
using namespace std;

// Treap node. All copies of one value share a single node: `cnt` is the
// multiplicity of `val`, and `siz` is the total multiplicity stored in the
// subtree rooted here. `prio` is the random heap key; merge() places the
// node with the smaller `prio` above the other.
struct Node{
    Node *ch[2];             // ch[0] = left child, ch[1] = right child
    int val, prio, cnt, siz;

    // explicit: prevents an int from silently converting into a Node.
    explicit Node(int _val) : val(_val), prio(rand()), cnt(1), siz(1){
        ch[0] = ch[1] = nullptr;
    }

    // Recompute siz from cnt and the children's sizes; callers must ensure
    // the children's siz fields are already up to date.
    void upd_siz(){
        siz = cnt;
        if (ch[0] != nullptr) siz += ch[0]->siz;
        if (ch[1] != nullptr) siz += ch[1]->siz;
    }
};

// Root of the global treap; nullptr means the tree is empty.
Node *rt = nullptr;

// Split the treap rooted at `cur` by value.
// Returns (A, B): A holds every node with val <= key, B every node with
// val > key. Both halves stay valid treaps, and the sizes along the cut
// path are refreshed.
pair<Node *, Node *> split(Node *cur, int key)
{
    if (!cur)
        return {nullptr, nullptr};

    if (key < cur->val)
    {
        // cur and its whole right subtree exceed key, so cur heads the upper half.
        pair<Node *, Node *> parts = split(cur->ch[0], key);
        cur->ch[0] = parts.second;
        cur->upd_siz();
        return make_pair(parts.first, cur);
    }

    // cur and its whole left subtree are <= key, so cur heads the lower half.
    pair<Node *, Node *> parts = split(cur->ch[1], key);
    cur->ch[1] = parts.first;
    cur->upd_siz();
    return make_pair(cur, parts.second);
}

// Split the treap rooted at `cur` by rank (1-based, counting multiplicities).
// Returns (left, mid, right):
//   mid   - the node whose copies cover rank rk, detached from its children,
//           with siz reset to its own cnt;
//   left  - all nodes ordered before mid;
//   right - all nodes ordered after mid.
// If rk is outside [1, cur->siz], mid is nullptr and every node lands in
// left or right.
tuple<Node *, Node *, Node *> split_by_rk(Node *cur, int rk)
{
    if (cur == nullptr)
        return {nullptr, nullptr, nullptr};
    int ls_siz = cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz;
    if (rk <= ls_siz)
    {
        // Target lies in the left subtree; cur goes to the "after" part.
        Node *l, *mid, *r;
        tie(l, mid, r) = split_by_rk(cur->ch[0], rk);
        cur->ch[0] = r;
        cur->upd_siz();
        return {l, mid, cur};
    }
    else if (rk <= ls_siz + cur->cnt)
    {
        // cur itself holds rank rk: detach it as a lone node.
        Node *left_sub = cur->ch[0];
        Node *right_sub = cur->ch[1];
        cur->ch[0] = cur->ch[1] = nullptr;
        // Without this, siz would still count the detached children and a
        // later merge() could propagate the stale size up the tree.
        cur->upd_siz();
        return {left_sub, cur, right_sub};
    }
    else
    {
        // Target lies in the right subtree; cur goes to the "before" part.
        Node *l, *mid, *r;
        tie(l, mid, r) = split_by_rk(cur->ch[1], rk - ls_siz - cur->cnt);
        cur->ch[1] = l;
        cur->upd_siz();
        return {cur, mid, r};
    }
}

// Merge two treaps into one. Precondition: every value in u is smaller than
// every value in v. The root with the strictly smaller prio wins; on a tie,
// v becomes the root. Returns the new root (nullptr if both are empty).
Node *merge(Node *u, Node *v)
{
    // An empty side contributes nothing (this also covers both sides empty).
    if (u == nullptr)
        return v;
    if (v == nullptr)
        return u;

    if (!(u->prio < v->prio))
    {
        // v stays on top; u is folded into v's left spine.
        v->ch[0] = merge(u, v->ch[0]);
        v->upd_siz();
        return v;
    }

    // u stays on top; v is folded into u's right spine.
    u->ch[1] = merge(u->ch[1], v);
    u->upd_siz();
    return u;
}

void insert(int val)
{
    auto tmp = split(rt, val);
    auto l_tr = split(tmp.first, val - 1);
    Node *new_node;
    if (l_tr.second == nullptr)
    {
        new_node = new Node(val);
    }
    else
    {
        l_tr.second->cnt++;
        l_tr.second->upd_siz();
    }
    Node *l_tr_combined =
        merge(l_tr.first, l_tr.second == nullptr ? new_node : l_tr.second);
    rt = merge(l_tr_combined, tmp.second);
}

void del(int val)
{
    auto tmp = split(rt, val);
    auto l_tr = split(tmp.first, val - 1);
    if (l_tr.second->cnt > 1)
    {
        l_tr.second->cnt--;
        l_tr.second->upd_siz();
        l_tr.first = merge(l_tr.first, l_tr.second);
    }
    else
    {
        if (tmp.first == l_tr.second)
        {
            tmp.first = nullptr;
        }
        delete l_tr.second;
        l_tr.second = nullptr;
    }
    rt = merge(l_tr.first, tmp.second);
}

// Rank of val in the treap rooted at cur: 1 + the number of stored copies
// strictly smaller than val (val itself need not be present).
// This is a read-only descent. The old version split/merged and then
// overwrote the global rt with `cur`, which broke whenever cur was a subtree.
int qrk_by_val(Node *cur, int val)
{
    int smaller = 0;
    while (cur != nullptr)
    {
        if (val <= cur->val)
        {
            // Everything smaller than val is in the left subtree.
            cur = cur->ch[0];
        }
        else
        {
            // The left subtree and all copies at cur are smaller than val.
            smaller += (cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz) + cur->cnt;
            cur = cur->ch[1];
        }
    }
    return smaller + 1;
}

// Value at rank rk (1-based, counting multiplicities) in the treap rooted at
// cur. This is a read-only descent: neither the tree nor the global rt is
// modified, so it is safe to call on any subtree. (The old version reassigned
// rt even when cur was a subtree, relying on the re-merged root being the
// same node.)
// Throws std::out_of_range when rk is not in [1, size of the subtree];
// previously this dereferenced a null pointer.
int qval_by_rk(Node *cur, int rk)
{
    while (cur != nullptr)
    {
        int ls_siz = cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz;
        if (rk <= ls_siz)
        {
            cur = cur->ch[0];
        }
        else if (rk <= ls_siz + cur->cnt)
        {
            return cur->val;
        }
        else
        {
            rk -= ls_siz + cur->cnt;
            cur = cur->ch[1];
        }
    }
    throw std::out_of_range("qval_by_rk: rank out of range");
}

int qpre(int val)
{
    auto tmp = split(rt, val - 1);
    int ret = qval_by_rk(tmp.first, tmp.first->siz);
    rt = merge(tmp.first, tmp.second);
    return ret;
}

int qnxt(int val)
{
    auto tmp = split(rt, val);
    int ret = qval_by_rk(tmp.second, 1);
    rt = merge(tmp.first, tmp.second);
    return ret;
}

// Reads t operations "op x" and applies them to the global treap:
//   1 insert x, 2 delete one x, 3 print rank of x,
//   4 print value at rank x, 5 print predecessor of x, 6 print successor of x.
// Any other op is ignored.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    srand(static_cast<unsigned>(time(nullptr)));
    int t = 0; // stays 0 (no queries) if the count cannot be read
    cin >> t;

    while (t--)
    {
        int op, x;
        if (!(cin >> op >> x))
            break; // truncated input: stop instead of using garbage values
        // '\n' instead of endl: avoid flushing the stream on every answer.
        switch (op)
        {
        case 1:
            insert(x);
            break;
        case 2:
            del(x);
            break;
        case 3:
            cout << qrk_by_val(rt, x) << '\n';
            break;
        case 4:
            cout << qval_by_rk(rt, x) << '\n';
            break;
        case 5:
            cout << qpre(x) << '\n';
            break;
        case 6:
            cout << qnxt(x) << '\n';
            break;
        default:
            break;
        }
    }
}