「KDOI-06-S」題解

FRZ_29發表於2024-09-15

T2 樹上異或

題面

分析

樹形 DP 題

考慮一顆子樹內部的某種割邊方式,假設其被分為 \(n\) 個連通塊,每個連通塊的權值分別為 \(a_1, a_2, \dots, a_n\),那麼該子樹在這種割邊方式下對答案的貢獻就為 \(\prod_{i = 1}^{n} a_i\)

因此就可以從葉子向根不斷合併,求出每種割邊狀態的值,時間複雜度為 \(O(2^{n - 1}n)\),期望得分 \(8\) 分。

這啟示往樹形 DP 的方向思考。

將每次定下割邊的方法轉變,考慮在 DP 過程中透過將兩個連通塊連線到一起,去遍歷每一種狀態。

這樣,每回溯到一個點:

  1. 遍歷該點的子樹
  2. 把與該點之間存在割邊的連通塊與該點之前所找到的連通塊合併
  3. 每次合併後求出該情況的貢獻(如圖,將藍色連通塊的權值異或在一起,然後計算結果)「KDOI-06-S」題解

實現困難,時間複雜度極高。

因為連通塊對答案的貢獻是 \(\prod_{i = 1} ^{n} a_i\) 的形式,故某子樹除去被合併的連通塊後不同情況產生的貢獻是可以累加的。(答案是 \(a_1 b_1+\dots+a_1 b_n+a_2 b_1+\dots+a_n b_n\),即 \((a_1+\dots+a_n)(b_1+\dots+b_n)\))。

而合併連通塊卻無法這樣最佳化。

對此,有一種方法能夠快速地合併連通塊——拆位。

具體來說,定義 \(f_{u, i, j}\) 表示以 \(u\) 所在的連通塊的權值第 \(i\) 位為 \(j\) 時以 \(u\) 為根節點的子樹除了\(u\) 所在的連通塊其他連通塊的乘積的值,定義 \(g_u\) 表示以 \(u\) 為根節點的子樹對答案的貢獻。

容易得到:$$g_u = \sum_{i = 0}^{63}f_{u, i, 1} \times 2^i$$,即 \(u\) 所在連通塊第 \(i\) 位為 \(1\) 時,所有的割邊方案的貢獻。

故僅需考慮 \(f_{u, i, j}\) 的轉移。

考慮當前遍歷到 \(u\) 的兒子節點 \(v\),則:

  1. 如果不合並,則 \(v\) 的子樹全都與 \(u\) 所在的連通塊無關,那麼 \(g_v\) 全都要乘到 \(f_{u, i, 0}\)
  2. 合併第 \(i\) 位為 \(1\) 的情況,如果連通塊原本為 \(1\) ,則與該子樹中第 \(i\) 位為 \(0\) 的異或後第 \(i\) 位仍然為 \(1\)。否則為與第 \(i\) 位為 \(1\) 的連通塊異或。
  3. \(i\) 位為 \(0\) 則恰好相反。

即:

\[f_{u, i, 0} = f_{u, i, 0} \times g_{v} + f_{v, i, 0} \times f_{u, i, 0} + f_{v, i, 1} \times f_{u, i, 1} \]

\[f_{u, i, 1} = f_{u, i, 1} \times g_v + f_{v, i, 0} \times f_{u, i, 1} + f_{v, i, 1} \times f_{u, i, 0} \]

答案為 \(g_1\)

注意

本題空間較小,動態規劃陣列開 long long 會爆。

點選檢視程式碼
/*
  --------------------------------
  |        code by FRZ_29        |
  |          code  time          |
  |          2024/09/15          |
  |           13:42:20           |
  |             星期天            |
  --------------------------------
                                  */

#include <iostream>
#include <climits>
#include <cstdio>
#include <ctime>
typedef long long LL;

using namespace std;

void RD() {}
template<typename T, typename... U> void RD(T &x, U&... arg) {
    x = 0; int f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    x *= f; RD(arg...);
}

const int N = 5e5 + 5;
const int mod = 998244353;

#define PRINT(x) cout << #x << "=" << x << "\n"
#define LF(i, __l, __r) for (int i = __l; i <= __r; i++)
#define RF(i, __r, __l) for (int i = __r; i >= __l; i--)

int head[N], Next[N << 1], ver[N << 1], tot = 1;
int n, f[N][65][2], g[N];
LL a[N];

void add(int u, int v) {
    ver[++tot] = v;
    Next[tot] = head[u], head[u] = tot;
}

void dfs(int u, int _f) {
    LF(i, 0, 63) f[u][i][a[u] >> i & 1] = 1;
    
    for (int i = head[u]; i; i = Next[i]) {
        int v = ver[i];
        if (v == _f) continue;
        dfs(v, u);

        LF(i, 0, 63) {
            LL t0 = f[u][i][0], t1 = f[u][i][1];
            f[u][i][0] = (t0 * g[v] + t0 * f[v][i][0] + t1 * f[v][i][1]) % mod;
            f[u][i][1] = (t1 * g[v] + t1 * f[v][i][0] + t0 * f[v][i][1]) % mod;
        }
    }

    LF(i, 0, 63) g[u] = (g[u] + (1LL << i) % mod * f[u][i][1]) % mod;
}

int main() {
//    freopen("read.in", "r", stdin);
//    freopen("out.out", "w", stdout);
//    time_t st = clock();
    RD(n);
    LF(i, 1, n) RD(a[i]);
    LF(u, 2, n) {
        int v; RD(v);
        add(u, v), add(v, u);
    }
    dfs(1, 0);
    printf("%d", g[1]);
//    printf("\n%dms", clock() - st);
    return 0;
}

/* ps:FRZ弱爆了 */