P1501 [國家集訓隊]Tree II

qq_45323960發表於2020-10-13

題目連結
題目描述
一棵 n n n 個點的樹,每個點的初始權值為 1 1 1
對於這棵樹有 q q q 個操作,每個操作為以下四種操作之一:

  • + u v c:將 u u u v v v 的路徑上的點的權值都加上自然數 c c c

  • - u1 v1 u2 v2:將樹中原有的邊 ( u 1 , v 1 ) (u_1,v_1) (u1,v1) 刪除,加入一條新邊 ( u 2 , v 2 ) (u_2,v_2) (u2,v2),保證操作完之後仍然是一棵樹;

  • u v c:將 u u u v v v 的路徑上的點的權值都乘上自然數 c c c

  • / u v:詢問 u u u v v v 的路徑上的點的權值和,將答案對 51061 51061 51061 取模。

輸入格式
第一行兩個整數 n , q n,q n,q

接下來 n − 1 n-1 n1 行每行兩個正整數 u , v u,v u,v,描述這棵樹的每條邊。

接下來 q q q 行,每行描述一個操作。

輸出格式
對於每個詢問操作,輸出一行一個整數表示答案。

輸入輸出樣例
輸入 #1

3 2
1 2
2 3
* 1 3 4
/ 1 1

輸出 #1

4

說明/提示
【資料範圍】
對於 100 % 100\% 100% 的資料, 1 ≤ n , q ≤ 1 0 5 1\le n,q \le 10^5 1n,q105 0 ≤ c ≤ 1 0 4 0\le c \le 10^4 0c104
分別設一個加和乘的lazy標籤維護

易錯點:

  • 先維護乘再維護加,因為有乘法分配律而加法不滿足分配律,例如 a × b + c a×b+c a×b+c 中不能把 c c c 放到 a × b a×b a×b 裡面。
  • 在轉移lazy標籤時要維護子樹權值和 s u m sum sum ,且 s u m sum sum 的維護要比權值 v a l val val 多一層。因為在splay改變時 s u m sum sum 的更新需要本節點的 v a l val val 和子節點的 s u m sum sum
  • 每次新增lazy標籤的根節點不需要維護 s u m sum sum 。因為splay的操作一定要經過該點,因此該點的 s u m sum sum 會被更新。
#include<bits/stdc++.h>

using namespace std;

inline int qr() {
    int f = 0, fu = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-')fu = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        f = (f << 3) + (f << 1) + c - 48;
        c = getchar();
    }
    return f * fu;
}

const int N = 1e5 + 10, MOD = 51061;

struct LCT {
    struct node {
        int fa, ch[2], siz;
        unsigned int val, sum, add, mul;
        bool v;
    } tr[N];

    inline void upd(int p) {
        tr[p].siz = tr[tr[p].ch[0]].siz + tr[tr[p].ch[1]].siz + 1;
        tr[p].sum = (tr[tr[p].ch[0]].sum + tr[tr[p].ch[1]].sum + tr[p].val) % MOD;
    }

    inline bool get(int p) {
        return p == tr[tr[p].fa].ch[1];
    }

    inline void spd(int p) {
        if (tr[p].v) {
            tr[p].v = false;
            swap(tr[p].ch[0], tr[p].ch[1]);
            if (tr[p].ch[0])tr[tr[p].ch[0]].v ^= 1;
            if (tr[p].ch[1])tr[tr[p].ch[1]].v ^= 1;
        }
        if (tr[p].mul != 1) {
            if (tr[p].ch[0]) {
                tr[tr[p].ch[0]].mul = (tr[tr[p].ch[0]].mul * tr[p].mul) % MOD;
                tr[tr[p].ch[0]].add = (tr[tr[p].ch[0]].add * tr[p].mul) % MOD;
                tr[tr[p].ch[0]].sum = (tr[tr[p].ch[0]].sum * tr[p].mul) % MOD;
            }
            if (tr[p].ch[1]) {
                tr[tr[p].ch[1]].mul = (tr[tr[p].ch[1]].mul * tr[p].mul) % MOD;
                tr[tr[p].ch[1]].add = (tr[tr[p].ch[1]].add * tr[p].mul) % MOD;
                tr[tr[p].ch[1]].sum = (tr[tr[p].ch[1]].sum * tr[p].mul) % MOD;
            }
            tr[p].val = (tr[p].val * tr[p].mul) % MOD;
            tr[p].mul = 1;
        }
        if (tr[p].add) {
            if (tr[p].ch[0]) {
                tr[tr[p].ch[0]].add = (tr[tr[p].ch[0]].add + tr[p].add) % MOD;
                tr[tr[p].ch[0]].sum = (tr[tr[p].ch[0]].sum + 1ll * tr[tr[p].ch[0]].siz * tr[p].add) % MOD;
            }
            if (tr[p].ch[1]) {
                tr[tr[p].ch[1]].add = (tr[tr[p].ch[1]].add + tr[p].add) % MOD;
                tr[tr[p].ch[1]].sum = (tr[tr[p].ch[1]].sum + 1ll * tr[tr[p].ch[1]].siz * tr[p].add) % MOD;
            }
            tr[p].val = (tr[p].val + tr[p].add) % MOD;
            tr[p].add = 0;
        }
    }


    inline bool isrt(int p) {
        return tr[tr[p].fa].ch[0] != p && tr[tr[p].fa].ch[1] != p;
    }

    inline void rot(int p) {
        int x = tr[p].fa, y = tr[x].fa, u = get(p), v = get(x);
        bool o = isrt(x);
        tr[tr[p].ch[u ^ 1]].fa = x, tr[x].ch[u] = tr[p].ch[u ^ 1];
        tr[x].fa = p, tr[p].ch[u ^ 1] = x, upd(x), upd(p);
        if ((tr[p].fa = y) && !o)tr[y].ch[v] = p;
    }


    inline void pre(int p) {
        stack<int> st;
        while (!isrt(p))st.push(p), p = tr[p].fa;
        spd(p);
        while (!st.empty())spd(st.top()), st.pop();
    }

    inline void splay(int p) {
        pre(p);
        for (int x = tr[p].fa; !isrt(p); rot(p), x = tr[p].fa)
            if (!isrt(x))rot(get(p) == get(x) ? x : p);
    }

    inline void access(int p) {
        for (int x = 0; p; p = tr[x = p].fa)splay(p), tr[p].ch[1] = x, upd(p);
    }

    inline void mkrt(int p) {
        access(p), splay(p), tr[p].v ^= 1;
    }

    inline void split(int x, int y) {
        mkrt(x), access(y), splay(y);
    }

    inline int fdrt(int p) {
        access(p), splay(p);
        while (tr[p].ch[0])spd(p), p = tr[p].ch[0];
        return splay(p), p;
    }

    inline bool link(int x, int y) {
        mkrt(x);
        return fdrt(y) != x && (tr[x].fa = y);
    }

    inline bool cut(int x, int y) {
        mkrt(x);
        if (fdrt(y) != x || tr[x].siz != 2)return false;
        return tr[y].fa = tr[x].ch[1] = 0, upd(x), true;
    }

    inline void init(int n) {
        for (int i = 1; i <= n; i++)tr[i].siz = tr[i].mul = tr[i].val = 1;
    }
} L;

int n, m;
char o[2];

int main() {
    n = qr(), m = qr(), L.init(n);
    for (int i = 1; i <= n - 1; i++)L.link(qr(), qr());
    while (m--) {
        scanf("%s", o);
        if (o[0] == '-')L.cut(qr(), qr()), L.link(qr(), qr());
        else {
            int x = qr(), y = qr();
            L.split(x, y);
            if (o[0] == '/') {
                printf("%u\n", L.tr[y].sum);
                continue;
            }
            o[0] == '+' ? L.tr[y].add = qr() : L.tr[y].mul = qr();
        }
    }
    return 0;
}

相關文章