換根 DP

RonChen發表於2024-11-09

樹形 DP 中的換根 DP 問題又被稱為二次掃描,通常需要求以每個點為根時某個式子的答案。

這一類問題通常需要遍歷兩次樹,第一次遍歷先求出以某個點(如 \(1\))為根時的答案,在第二次遍歷時考慮由根為 \(u\) 轉化為根為 \(v\) 時答案的變化(換根)。這個變化往往分為兩部分,\(v\) 子樹外的點到 \(v\) 相比於到 \(u\) 會增加一條邊,而 \(v\) 子樹內的點到 \(v\) 相比於到 \(u\) 會減少一條邊。

所以往往在第一次遍歷時可以順帶求出一些子樹資訊,利用這些資訊輔助第二次遍歷時的換根操作。

經典例題:求對於每個點而言,其他點到這個點的距離之和。

例題:P3478 [POI2008] STA-Station

給定一棵 \(n\) 個節點的樹,求出一個節點,使得以這個節點為根時,所有節點的深度之和最大。
資料範圍:\(n \le 10^6\)

分析:隨便選擇一個節點 \(u\) 作為根節點,遍歷整棵樹,則得到了以 \(u\) 為根節點時的深度之和。

\(dp_u\) 表示以 \(u\) 為根時,所有節點的深度之和。設 \(v\) 為當前節點的某個子節點,考慮“換根”,即以 \(u\) 為根轉移到以 \(v\) 為根,顯然在換根的過程中,以 \(v\) 為根會導致每個節點的深度都產生改變。具體表現為:

  • 所有在 \(v\) 的子樹上的節點深度都減少了一,那麼總深度和就減少了 \(sz_v\),這裡用 \(sz_i\) 表示以 \(i\) 為根的子樹中的節點個數。
  • 所有不在 \(v\) 的子樹上的節點深度都增加了一,那麼總深度和就增加了 \(n - sz_v\)

根據這兩個條件就可以推出狀態轉移方程:\(dp_v = dp_u + n - 2 \times sz_v\),因此可以在第一次遍歷時順便計算一下 \(sz\),第二次遍歷時用狀態轉移方程計算出最終的答案。

參考程式碼
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 1000005;
vector<int> tree[N];
int sz[N], res, n;
LL ans;
void dfs(int cur, int fa, int depth) {
    ans += depth;
    sz[cur] = 1;
    for (int to : tree[cur]) {
        if (to == fa) continue;
        dfs(to, cur, depth + 1);
        sz[cur] += sz[to];
    }
}
void solve(int cur, int fa, LL sum) {
    for (int to : tree[cur]) {
        if (to == fa) continue;
        LL tmp = sum + n - 2 * sz[to];
        if (tmp > ans) {
            ans = tmp; res = to;
        }
        solve(to, cur, tmp);
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        tree[u].push_back(v); tree[v].push_back(u);
    }
    dfs(1, 0, 1);
    res = 1;
    solve(1, 0, ans);
    printf("%d\n", res);
    return 0;
}

習題:P2986 [USACO10MAR] Great Cow Gathering G

解題思路

與上一題類似,只不過這題換根時的變化量是點權和(即牛棚中奶牛的數量)的變化乘以邊權。

參考程式碼
#include <cstdio>
#include <vector>
using namespace std;
typedef long long LL;
const int N = 100005;
int n, c[N];
LL ans;
struct Edge {
    int to, l;
};
vector<Edge> tree[N];
void dfs(int cur, int fa, int depth) {
    ans += 1ll * depth * c[cur];
    for (Edge e : tree[cur]) {
        if (e.to == fa) continue;
        dfs(e.to, cur, depth + e.l);
        c[cur] += c[e.to];
    }
}
void solve(int cur, int fa, LL sum) {
    for (Edge e : tree[cur]) {
        if (e.to == fa) continue;
        LL tmp = sum + 1ll * (c[1] - 2 * c[e.to]) * e.l;
        ans = min(ans, tmp);
        solve(e.to, cur, tmp);
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
    for (int i = 1; i < n; i++) {
        int a, b, l; scanf("%d%d%d", &a, &b, &l);
        tree[a].push_back({b, l}); tree[b].push_back({a, l});
    }
    dfs(1, 0, 0);
    solve(1, 0, ans);
    printf("%lld\n", ans);
    return 0;
}

例題:P3047 [USACO12FEB] Nearby Cows G

分析:可以先對樹做一次遍歷得到每個節點對應的子樹下距離子樹根節點距離 \(0 \sim x\) 之間的點權和,

然後考慮每個點距離 \(k\) 之內的點權和。子樹下的點權和在第一次遍歷時已經計算完成,因此還需要計算的是在該點子樹外的距離 \(k\) 以內的部分,而這個部分可以透過對該點上方最多 \(k\) 個祖先節點的處理,如下圖所示。

image

參考程式碼
#include <cstdio>
#include <vector>
using std::vector;
const int N = 100005;
const int K = 25;
vector<int> tree[N];
int n, k, sum[N][K], c[N], ans[N];
void dfs(int u, int fa) {
    for (int v : tree[u]) {
        if (v == fa) continue;
        dfs(v, u);
        for (int i = 1; i <= k; i++) {
            sum[u][i] += sum[v][i - 1];
        }
    }
    sum[u][0] = c[u];
}
void calc(int u, int fa, vector<int> pre) {
    int dis = pre.size();
    pre.push_back(u);
    // u的子樹內的距離範圍內的點權和
    ans[u] = sum[u][k];
    // 計算u的子樹外的距離範圍內的點權和
    for (int i = 0; i + 1 < pre.size(); i++) {
        int cur = pre[i], nxt = pre[i + 1];
        // 對於邊cur->nxt
        ans[u] += sum[cur][k - dis]; // 加上cur子樹下的距離內點權和
        if (k - dis > 0) ans[u] -= sum[nxt][k - dis - 1]; // 減去nxt子樹下剛剛被重複計算的部分
        dis--;
    }
    vector<int> path;
    if (pre.size() == k + 1) {
        // pre[0]即將超出下面的點的距離範圍k,要被淘汰
        for (int i = 1; i < pre.size(); i++) path.push_back(pre[i]);
    } else path = pre;
    for (int v : tree[u]) {
        if (v == fa) continue;
        calc(v, u, path);
    }
}
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        tree[u].push_back(v); tree[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
    dfs(1, 0);
    // 生成字首和
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= k; j++) sum[i][j] += sum[i][j - 1];
    }
    vector<int> tmp;
    calc(1, 0, tmp);
    for (int i = 1; i <= n; i++) printf("%d\n", ans[i]);
    return 0;
}