樹上數顏色

onlyblues發表於2024-04-03

樹上數顏色

題目描述

給一棵根為 $1$ 的樹,每次詢問子樹顏色種類數。

輸入格式

第一行一個整數 $n$,表示樹的結點數。

接下來 $n-1$ 行,每行一條邊。

接下來一行 $n$ 個數,表示每個結點的顏色 $c[i]$。

接下來一個數 $m$,表示詢問數。

接下來 $m$ 行表示詢問的子樹。

輸出格式

對於每個詢問,輸出該子樹顏色數。

樣例 #1

樣例輸入 #1

5
1 2
1 3
2 4
2 5
1 2 2 3 3
5
1
2
3
4
5

樣例輸出 #1

3
2
1
1
1

提示

對於前三組資料,有 $1\leq m,c[i]\leq n\leq 100$。

而對於所有資料,有$1\leq m,c[i]\leq n\leq 10^5$。

解題思路

  今天學了 dsu on tree,大概記錄一下原理。

  dsu on tree 主要用於解決靜態(沒有修改操作)的子樹問題,名字上雖然有並查集(dsu),但實際上與並查集或啟發式合併沒有太大的關係,反而用到了樹鏈剖分中重兒子的思想。對於節點 $u$,其重兒子就是 $u$ 的所有子節點中子樹最大的子節點。如果有多個子樹最大的子節點則任取其一。如果沒有子節點,則無重兒子。對應的輕兒子就是除重兒子外剩餘的所有子節點。

  大部分 dsu on tree 的做法都是透過 dfs 遍歷每個節點 $u$ 並進行以下 $3$ 個步驟:

  1. 先對所有輕兒子進行 dfs 求答案,但不記錄 dfs 過程中每個節點對 $u$ 的貢獻。
  2. 對其重兒子進行 dfs 求答案,並記錄 dfs 過程中每個節點對 $u$ 的貢獻。
  3. 再次對所有輕兒子進行 dfs,記錄 dfs 過程中每個節點對 $u$ 的貢獻,從而求出節點 $u$ 的答案。

  在這道題目中,我們用 $\text{cnt}$ 陣列來記錄顏色的出現次數,並用 $s$ 來維護 $\text{cnt}$ 中出現了多少種不同的顏色。

  看上去第 $3$ 步完全可以合併到第 $1$ 步中,實際上這是不對的,這會導致 $\text{cnt}$ 陣列被重複使用,即其他子樹的記錄會影響到當前子樹的答案。而如果對每個節點都開一個 $\text{cnt}$ 陣列,則會導致 $O(n^2)$ 的空間複雜度。

  dsu on tree 的時間複雜度是 $O(n \log{n})$,大概的解釋是樹中任意節點到根的路徑中的輕邊數量不超過 $\log{n}$ 條,意味著在 dfs 的過程中每個節點會被遍歷 $O(\log{n})$ 次,因此所有節點被遍歷的次數就是 $O(n \log{n})$。具體證明參見:樹上啟發式合併

  下面大概講一下本題的程式碼實現。

  首先先透過 dfs 求出每個節點 $u$ 的子節點的子樹大小,並選擇子樹大小最大的子節點作為 $u$ 的重兒子 $\text{son}[u]$,時間複雜度為 $O(n)$。

  然後再 dfs 進行 dsu on tree,對於每個節點都按照上面描述的 $3$ 個步驟進行。其中記錄貢獻的部分就是對 $u$ 的所有輕兒子執行另一個 dfs,並將每個節點的顏色記錄到 $\text{cnt}$ 中。最後如果 $u$ 是其父節點的輕兒子,還需要清除 $\text{cnt}$ 中的所有記錄。

  AC 程式碼如下,時間複雜度為 $O(n \log{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 1e5 + 5, M = N * 2;

int a[N];
int h[N], e[M], ne[M], idx;
int sz[N], son[N];
int ans[N], cnt[N], s;

void add(int u, int v) {
    e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}

void dfs(int u, int p) {
    sz[u] = 1;
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p) continue;
        dfs(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

void modify(int u, int p, int c, int pson) {
    cnt[a[u]] += c;
    if (cnt[a[u]] == 1 && c == 1) s++;
    if (cnt[a[u]] == 0 && c == -1) s--;
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p || v == pson) continue;
        modify(v, u, c, pson);
    }
}

void dfs(int u, int p, int keep) {
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p || v == son[u]) continue;
        dfs(v, u, 0);
    }
    if (son[u]) dfs(son[u], u, 1);
    modify(u, p, 1, son[u]);
    ans[u] = s;
    if (!keep) modify(u, p, -1, -1);
}

int main() {
    int n, m;
    scanf("%d", &n);
    memset(h, -1, sizeof(h));
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        add(u, v), add(v, u);
    }
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
    }
    dfs(1, -1);
    dfs(1, -1, 0);
    scanf("%d", &m);
    while (m--) {
        int x;
        scanf("%d", &x);
        printf("%d\n", ans[x]);
    }
    
    return 0;
}

  這題還可以用啟發式合併來做(貌似大部分 dsu on tree 的題都可以用啟發式合併來實現)。

  對於每個節點 $u$ 都開一個 std::set<int> 用來儲存子樹 $u$ 中包含的不同顏色,表示為 $\text{st}[u]$。透過 dfs 求出 $u$ 的每個子節點 $v$ 的 $\text{st}[v]$,並將 $\text{st}[v]$ 的元素合併到 $\text{st}[u]$ 中。如果 $\text{st}[v]$ 的大小不超過 $\text{st}[u]$,則直接合並即可。否則需要將兩個集合進行互換,再將 $\text{st}[v]$ 合併到 $\text{st}[u]$ 中。這就是啟發式合併。

  用 std::set<int> 實現的時間複雜度為 $O(n \log^2{n})$,改為 std::unordered_set<int> 的話就是 $O(n \log{n})$,不過任意被卡雜湊函式還是建議用 std::set<int>

  AC 程式碼如下,時間複雜度為 $O(n \log^2{n})$:

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 1e5 + 5, M = N * 2;

int a[N];
int h[N], e[M], ne[M], idx;
set<int> st[N];
int ans[N];

void add(int u, int v) {
    e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}

void dfs(int u, int p) {
    st[u].insert(a[u]);
    for (int i = h[u]; i != -1; i = ne[i]) {
        int v = e[i];
        if (v == p) continue;
        dfs(v, u);
        if (st[v].size() > st[u].size()) st[u].swap(st[v]);
        st[u].insert(st[v].begin(), st[v].end());
        st[v].clear();
    }
    ans[u] = st[u].size();
}

int main() {
    int n, m;
    scanf("%d", &n);
    memset(h, -1, sizeof(h));
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        add(u, v), add(v, u);
    }
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
    }
    dfs(1, -1);
    scanf("%d", &m);
    while (m--) {
        int x;
        scanf("%d", &x);
        printf("%d\n", ans[x]);
    }
    
    return 0;
}

參考資料

  樹上啟發式合併:https://oi-wiki.org/graph/dsu-on-tree/

相關文章