樹形DP

RonChen發表於2024-05-21

樹形 DP 即在樹上進行的 DP。

常見的兩種轉移方向:

  • 父節點 \(\rightarrow\) 子節點:如求節點深度,\(dp_u = dp_{fa} + 1\)
  • 子節點 \(\rightarrow\) 父節點:如求子樹大小,\(dp_u = 1 + \sum dp_v\)

習題:P5658 [CSP-S2019] 括號樹

暴力

本題 \(n\) 小的資料點保證為鏈,直接列舉 \(i\),代表從根節點到 \(i\) 號節點。列舉 \([1,i]\) 中的子區間左右端點 \([l,r]\),判斷該子串是否符合括號匹配。

#include <cstdio>
typedef long long LL;
const int N = 500005;
char s[N];
int f[N];
bool check(int l, int r) {
    int left = 0;
    for (int i = l; i <= r; i++) {
        if (s[i] == '(') left++;
        else if (left == 0) return false;
        else left--;
    }
    return left == 0;
}
int main()
{
    int n; scanf("%d%s", &n, s + 1);
    for (int i = 2; i <= n; i++) scanf("%d", &f[i]);
    LL ans = 0;
    for (int i = 2; i <= n; i++) {
        // 1~i
        LL cnt = 0;
        for (int l = 1; l <= i; l++) {
            for (int r = l; r <= i; r++) {
                // [l, r]
                if (check(l, r)) {
                    cnt++;
                }
            }
        }
        ans ^= cnt * i;
    }
    printf("%lld\n", ans);
    return 0;
}

實際得分 \(20\) 分。

特殊性質

考慮資料為鏈但 \(n \le 5 \times 10^5\) 的問題做法,此時可以看作是一個線性序列上的問題。

考慮用 \(dp_i\) 表示以 \(s_i\) 結尾的合法括號串,如果 \(s_i\) 是左括號,顯然 \(dp_i = 0\);而如果 \(s_i\) 是右括號,這時實際上需要找到與該右括號對應匹配的左括號(這個問題可以藉助一個棧來實現),則該左括號到當前的右括號構成了一個合法括號串,而實際上如果這個左括號的前一位是一個合法括號串的結尾,那麼之前的合法括號串拼上剛匹配的這個括號串也是一個合法括號串,因此這時 \(dp_i = dp_{pre-1} + 1\),這裡的 \(pre\) 代表當前右括號匹配的左括號的位置。

題目要求計算的是合法括號子串的數量,因此只需計算 \(dp\) 結果的字首和即為前 \(i\) 個字元形成的字串中合法括號子串的數量。

#include <cstdio>
#include <stack>
using namespace std;
typedef long long LL;
const int N = 500005;
char s[N];
int f[N];
LL dp[N]; // 以s[i]結尾的括號子串數量
LL sum[N]; // 1~i中的括號子串數量,即dp的字首和
int main()
{
    int n; scanf("%d%s", &n, s + 1);
    for (int i = 2; i <= n; i++) scanf("%d", &f[i]);
    stack<int> stk; // 記錄左括號的位置
    LL ans = 0;
    for (int i = 1; i <= n; i++) {
        if (s[i] == '(') {
            stk.push(i);
        } else if (!stk.empty()) {
            int pre = stk.top();
            stk.pop();
            dp[i] = dp[pre - 1] + 1;
        }
        sum[i] = sum[i - 1] + dp[i];
        ans ^= sum[i] * i;
    }
    printf("%lld\n", ans);
    return 0;
}

實際得分 \(55\) 分。

解題思路

把處理鏈的思路轉化到任意樹上。

其中 \(dp\)\(sum\) 的計算方式可以類推過來,只不過鏈上透過“減一”表達上一個位置的方式對應到樹上要變成“父節點”。因此原來的計算式子需要調整一下:

  • \(dp_i = dp_{pre-1} + 1 \rightarrow dp_i = dp_{f_{pre}} + 1\)
  • \(sum_i = sum_{i-1} + dp_i \rightarrow sum_i = sum_{f_i} + dp_i\)

除此以外,還需要解決樹上的括號棧的遞迴與回溯問題。發生回溯後,棧裡的資訊可能會和當前狀態不匹配。比如某個節點(左括號)有多棵子樹,進入其中一棵子樹之後,該子樹中的右括號匹配掉了這個左括號(出棧),而接下來再進入下一棵子樹時這個左括號依然需要在棧中。

因此回溯時,我們要執行當時遞迴時相反的操作。比如,當前節點是右括號,如果此時棧不為空,棧會彈出一個元素以匹配當前右括號。我們可以記錄這個資訊,在最後回溯前把它重新壓入棧中,保持狀態的一致性。

參考程式碼
#include <cstdio>
#include <vector>
#include <stack>
using namespace std;
typedef long long LL;
const int N = 500005;
char s[N];
vector<int> tree[N];
stack<int> stk;
int f[N];
LL dp[N], sum[N];
void dfs(int cur, int fa) {
    int tmp = 0;
    if (s[cur] == '(') {
        stk.push(cur); tmp = -1;
        dp[cur] = 0;
    } else if (stk.empty()) {
        dp[cur] = 0;
    } else {
        tmp = stk.top(); stk.pop(); 
        dp[cur] = dp[f[tmp]] + 1; 
    }
    sum[cur] = sum[fa] + dp[cur];
    for (int to : tree[cur]) dfs(to, cur);
    if (tmp == -1) stk.pop();
    else if (tmp > 0) stk.push(tmp);
}
int main()
{
    int n;
    scanf("%d%s", &n, s + 1);
    for (int i = 2; i <= n; i++) {
        scanf("%d", &f[i]);
        tree[f[i]].push_back(i);
    }
    dfs(1, 0);
    LL ans = 0;
    for (int i = 1; i <= n; i++) ans ^= (sum[i] * i);
    printf("%lld\n", ans);
    return 0;
}

習題:P4084 [USACO17DEC] Barn Painting G

解題思路

\(dp_{u,c}\) 表示以 \(u\) 為根節點的子樹,節點 \(u\) 的顏色為 \(c\) 的方案數,即對於所有初始狀態,\(dp_{u,1} = dp_{u,2} = dp_{u,3} = 1\),如果某個節點被上了指定的顏色,那麼該節點的狀態中另外兩種上色狀態方案數為 \(0\)

對於每個節點,由於不能與子節點顏色相同,則有:

  • \(dp_{u,1} = \prod \limits_{v \in son_u} (dp_{v,2} + dp_{v,3})\)
  • \(dp_{u,2} = \prod \limits_{v \in son_u} (dp_{v,1} + dp_{v,3})\)
  • \(dp_{u,3} = \prod \limits_{v \in son_u} (dp_{v,1} + dp_{v,2})\)
參考程式碼
#include <cstdio>
#include <vector>
using namespace std;
const int N = 100005;
const int MOD = 1000000007;
vector<int> tree[N];
int c[N], dp[N][4];
void dfs(int u, int fa) {
    int ans1 = 1, ans2 = 1, ans3 = 1;
    for (int v : tree[u]) {
        if (v == fa) continue;
        dfs(v, u);
        ans1 = 1ll * (dp[v][2] + dp[v][3]) % MOD * ans1 % MOD;
        ans2 = 1ll * (dp[v][1] + dp[v][3]) % MOD * ans2 % MOD;
        ans3 = 1ll * (dp[v][1] + dp[v][2]) % MOD * ans3 % MOD;
    }
    if (c[u] == 0 || c[u] == 1) dp[u][1] = ans1;
    if (c[u] == 0 || c[u] == 2) dp[u][2] = ans2;
    if (c[u] == 0 || c[u] == 3) dp[u][3] = ans3;
}
int main()
{
    int n, k;
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        tree[x].push_back(y); tree[y].push_back(x);
    }
    while (k--) {
        int b;
        scanf("%d", &b); scanf("%d", &c[b]);
    }
    dfs(1, 0);
    printf("%d\n", ((dp[1][1] + dp[1][2]) % MOD + dp[1][3]) % MOD);
    return 0;
}

習題:P2899 [USACO08JAN] Cell Phone Network G

解題思路

注意本題和 P2016 戰略遊戲 的區別,戰略遊戲是選擇一些點從而覆蓋所有的邊,本題是選擇一些點從而覆蓋所有的點。

在戰略遊戲中,一條邊可能會被兩端的點覆蓋到,因此對於每個點對應的子樹需要設計兩個狀態(選/不選)。類似地,在本題中,我們可以要分三種狀態:

  • \(dp_{u,0}\) 表示 \(u\) 被自己覆蓋的情況下對應子樹的最少訊號塔數量
  • \(dp_{u,1}\) 表示 \(u\) 被子節點覆蓋的情況下對應子樹的最少訊號塔數量
  • \(dp_{u,2}\) 表示 \(u\) 被父節點覆蓋的情況下對應子樹的最少訊號塔數量

則有狀態轉移:

  • \(dp_{u,0} = \sum \limits_{v \in son_u} \min \{dp_{v,0},dp_{v,1},dp_{v,2}\}\),因為 \(u\) 處自己放置了訊號塔,因此子節點處放或不放都可以
  • \(dp_{u,1} = dp_{v',0} + \sum \limits_{v \in son_u \land v \ne v'} \min \{ dp_{v,0},dp_{v,1} \}\),此時至少要有一個子節點放置訊號塔,其他可放可不放,因此 \(v'\) 應該是所有子節點 \(v\)\(dp_{v,0} - \min \{ dp_{v,0}, dp_{v,1} \}\) 最小的那個子節點;注意若 \(u\) 沒有子樹即 \(u\) 為葉子節點,此時 \(dp_{u,1}=1\)
  • \(dp_{u,2} = \sum \limits_{v \in son_u} \min \{ dp_{v,0}, dp_{v,1} \}\),因為本節點處不放,靠父節點放置來覆蓋,所以子節點中除了狀態 \(2\) 以外都可以
參考程式碼
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 10005;
vector<int> tree[N];
int dp[N][3];
// dp[u][0]:u處放置
// dp[u][1]:u處依賴子節點放置
// dp[u][2]:u處依賴父節點放置
void dfs(int u, int fa) {
    dp[u][0] = 1;
    int best = -1;
    for (int v : tree[u]) {
        if (v == fa) continue;
        dfs(v, u);
        dp[u][0] += min(min(dp[v][0], dp[v][1]), dp[v][2]);
        dp[u][2] += min(dp[v][0], dp[v][1]);
        dp[u][1] += min(dp[v][0], dp[v][1]);
        // 尋找必須要放置的那個子節點
        int cur_diff = dp[v][0] - min(dp[v][0], dp[v][1]);
        int best_diff = dp[best][0] - min(dp[best][0], dp[best][1]);
        if (best == -1 || cur_diff < best_diff)
            best = v;
    }
    if (best != -1) {
        // 至少要在一個子節點處放置
        dp[u][1] += dp[best][0] - min(dp[best][0], dp[best][1]);
    } else {
        dp[u][1] = 1; // 沒有子樹,必須放置
    }
}
int main()
{
    int n; scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int a, b; scanf("%d%d", &a, &b);
        tree[a].push_back(b);
        tree[b].push_back(a);
    }
    dfs(1, 0);
    printf("%d\n", min(dp[1][0], dp[1][1]));
    return 0;
}

習題:P3574 [POI2014] FAR-FarmCraft

解題思路

\(dp_u\) 表示假如在 \(0\) 時刻到達 \(u\),它的子樹被安裝好的時間。

發現對下面子樹的遍歷順序會影響最終結果,考慮這個順序,假設針對 \(u\) 的某兩棵子樹 \(v_1\)\(v_2\)

  • 假設先走 \(v_1\) 再走 \(v_2\),則此時可能的完成時間是 \(\max (1 + dp_{v_1}, 2 \times sz_{v_1} + 1 + dp_{v_2})\),前者表示 \(v_1\) 那棵子樹完成時間更晚,後者表示 \(v_2\) 那棵子樹完成時間更晚,此時要先走完 \(v_1\) 子樹再走到 \(v_2\) 才能加 \(dp_{v_2}\)
  • 假設先走 \(v_2\) 再走 \(v_1\),則此時可能的完成時間是 \(\max (1 + dp_{v_2}, 2 \times sz_{v_2} + 1 + dp_{v_1})\)

顯然我們希望 \(v_1\) 子樹和 \(v_2\) 子樹形成更好的遍歷順序,考慮按這上面的式子對子樹排序。

注意,用上面的式子比大小對子節點排序需要證明 \(\max (1 + dp_{v_1}, 2 \times sz_{v_1} + 1 + dp_{v_2}) < \max (1 + dp_{v_2}, 2 \times sz_{v_2} + 1 + dp_{v_1})\) 這個式子具有傳遞性。這是可以證明的:假設小於號前面的式子取到第一項,此時這個式子必然滿足,因為小於號後面式子的第二項必然比它大,傳遞性顯然成立;假如小於號前面的式子取到第二項,此時相當於需要 \(2 \times sz_{v_1} + dp_{v_2} < 2 \times sz_{v_2} + dp_{v_1}\),這個式子經過移項可以使得小於號左邊只和 \(v_1\) 有關,右邊只和 \(v_2\) 有關,因此傳遞性得證。

所以我們可以按這種方式對子樹排序,按照子樹的遍歷依次更新 \(dp_u\),這裡的轉移式是 \(2 \times sum + 1 + dp_v\),其中 \(sum\) 代表在 \(v\) 這棵子樹之前的子樹大小總和。

注意最後答案是 \(dp_1\)\(2 \times (n-1) + c_1\) 的較大值,因為題目要求走一圈後回到點 \(1\) 才能開始給 \(1\) 裝軟體。

參考程式碼
#include <cstdio>
#include <vector>
#include <algorithm>
using std::vector;
using std::sort;
using std::max;
const int N = 5e5 + 5;
vector<int> tree[N];
int c[N], sz[N], n, dp[N];
void dfs(int u, int fa) {
    dp[u] = c[u]; sz[u] = 1;
    for (int v : tree[u]) {
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
    }
    sort(tree[u].begin(), tree[u].end(), [](int i, int j) {
        int i_before_j = max(1 + dp[i], 2 * sz[i] + 1 + dp[j]);
        int j_before_i = max(1 + dp[j], 2 * sz[j] + 1 + dp[i]);
        return i_before_j < j_before_i;
    });
    int sum = 0;
    for (int v : tree[u]) {
        if (v == fa) continue;
        dp[u] = max(dp[u], 2 * sum + 1 + dp[v]);
        sum += sz[v];
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &c[i]);
    }
    for (int i = 1; i < n; i++) {
        int a, b; scanf("%d%d", &a, &b);
        tree[a].push_back(b); tree[b].push_back(a);
    }
    dfs(1, 0);
    // 1號點要等回來才能裝,所以要考慮2*(n-1)+c[1]
    printf("%d\n", max(dp[1], 2 * (n - 1) + c[1])); 
    return 0;
}

換根 DP

樹形 DP 中的換根 DP 問題又被稱為二次掃描,通常需要求以每個點為根時某個式子的答案。

這一類問題通常需要遍歷兩次樹,第一次遍歷先求出以某個點(如 \(1\))為根時的答案,在第二次遍歷時考慮由根為 \(u\) 轉化為根為 \(v\) 時答案的變化(換根)。這個變化往往分為兩部分,\(v\) 子樹外的點到 \(v\) 相比於到 \(u\) 會增加一條邊,而 \(v\) 子樹內的點到 \(v\) 相比於到 \(u\) 會減少一條邊。所以往往在第一次遍歷時可以順帶求出一些子樹資訊,利用這些資訊輔助第二次遍歷時的換根操作。

例題:P3478 [POI2008] STA-Station

給定一棵 \(n\) 個節點的樹,求出一個節點,使得以這個節點為根時,所有節點的深度之和最大。
資料範圍:\(n \le 10^6\)

隨便選擇一個節點 \(u\) 作為根節點,遍歷整棵樹,則得到了以 \(u\) 為根節點時的深度之和。

\(dp_u\) 表示以 \(u\) 為根時,所有節點的深度之和。設 \(v\) 為當前節點的某個子節點,考慮“換根”,即以 \(u\) 為根轉移到以 \(v\) 為根,顯然在換根的過程中,以 \(v\) 為根會導致每個節點的深度都產生改變。具體表現為:

  • 所有在 \(v\) 的子樹上的節點深度都減少了一,那麼總深度和就減少了 \(sz_v\),這裡用 \(sz_i\) 表示以 \(i\) 為根的子樹中的節點個數。
  • 所有不在 \(v\) 的子樹上的節點深度都增加了一,那麼總深度和就增加了 \(n - sz_v\)

根據這兩個條件就可以推出狀態轉移方程:\(dp_v = dp_u + n - 2 \times sz_v\),因此可以在第一次遍歷時順便計算一下 \(sz\),第二次遍歷時用狀態轉移方程計算出最終的答案。

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 1000005;
vector<int> tree[N];
int sz[N], res, n;
LL ans;
void dfs(int cur, int fa, int depth) {
    ans += depth;
    sz[cur] = 1;
    for (int to : tree[cur]) {
        if (to == fa) continue;
        dfs(to, cur, depth + 1);
        sz[cur] += sz[to];
    }
}
void solve(int cur, int fa, LL sum) {
    for (int to : tree[cur]) {
        if (to == fa) continue;
        LL tmp = sum + n - 2 * sz[to];
        if (tmp > ans) {
            ans = tmp; res = to;
        }
        solve(to, cur, tmp);
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        tree[u].push_back(v); tree[v].push_back(u);
    }
    dfs(1, 0, 1);
    res = 1;
    solve(1, 0, ans);
    printf("%d\n", res);
    return 0;
}

習題:P2986 [USACO10MAR] Great Cow Gathering G

解題思路

與上一題類似,只不過這題換根時的變化量是點權和(即牛棚中奶牛的數量)的變化乘以邊權。

參考程式碼
#include <cstdio>
#include <vector>
using namespace std;
typedef long long LL;
const int N = 100005;
int n, c[N];
LL ans;
struct Edge {
    int to, l;
};
vector<Edge> tree[N];
void dfs(int cur, int fa, int depth) {
    ans += 1ll * depth * c[cur];
    for (Edge e : tree[cur]) {
        if (e.to == fa) continue;
        dfs(e.to, cur, depth + e.l);
        c[cur] += c[e.to];
    }
}
void solve(int cur, int fa, LL sum) {
    for (Edge e : tree[cur]) {
        if (e.to == fa) continue;
        LL tmp = sum + 1ll * (c[1] - 2 * c[e.to]) * e.l;
        ans = min(ans, tmp);
        solve(e.to, cur, tmp);
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
    for (int i = 1; i < n; i++) {
        int a, b, l; scanf("%d%d%d", &a, &b, &l);
        tree[a].push_back({b, l}); tree[b].push_back({a, l});
    }
    dfs(1, 0, 0);
    solve(1, 0, ans);
    printf("%lld\n", ans);
    return 0;
}

習題:P3047 [USACO12FEB] Nearby Cows G

解題思路

可以先對樹做一次遍歷得到每個節點對應的子樹下距離子樹根節點距離 \(0 \sim x\) 之間的點權和,

然後考慮每個點距離 \(k\) 之內的點權和。子樹下的點權和在第一次遍歷時已經計算完成,因此還需要計算的是在該點子樹外的距離 \(k\) 以內的部分,而這個部分可以透過對該點上方最多 \(k\) 個祖先節點的處理,如下圖所示。

image

參考程式碼
#include <cstdio>
#include <vector>
using std::vector;
const int N = 100005;
const int K = 25;
vector<int> tree[N];
int n, k, sum[N][K], c[N], ans[N];
void dfs(int u, int fa) {
    for (int v : tree[u]) {
        if (v == fa) continue;
        dfs(v, u);
        for (int i = 1; i <= k; i++) {
            sum[u][i] += sum[v][i - 1];
        }
    }
    sum[u][0] = c[u];
}
void calc(int u, int fa, vector<int> pre) {
    int dis = pre.size();
    pre.push_back(u);
    // u的子樹內的距離範圍內的點權和
    ans[u] = sum[u][k];
    // 計算u的子樹外的距離範圍內的點權和
    for (int i = 0; i + 1 < pre.size(); i++) {
        int cur = pre[i], nxt = pre[i + 1];
        // 對於邊cur->nxt
        ans[u] += sum[cur][k - dis]; // 加上cur子樹下的距離內點權和
        if (k - dis > 0) ans[u] -= sum[nxt][k - dis - 1]; // 減去nxt子樹下剛剛被重複計算的部分
        dis--;
    }
    vector<int> path;
    if (pre.size() == k + 1) {
        // pre[0]即將超出下面的點的距離範圍k,要被淘汰
        for (int i = 1; i < pre.size(); i++) path.push_back(pre[i]);
    } else path = pre;
    for (int v : tree[u]) {
        if (v == fa) continue;
        calc(v, u, path);
    }
}
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        tree[u].push_back(v); tree[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
    dfs(1, 0);
    // 生成字首和
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= k; j++) sum[i][j] += sum[i][j - 1];
    }
    vector<int> tmp;
    calc(1, 0, tmp);
    for (int i = 1; i <= n; i++) printf("%d\n", ans[i]);
    return 0;
}

樹上揹包

樹上的揹包問題,也就是揹包問題與樹形 DP 的結合。

例題:P2014 [CTSC1997] 選課

\(n\) 門課程,第 \(i\) 門課程的學分為 \(a_i\),每門課程有零門或一門先修課,有先修課的課程需要先學完其先修課,才能學習該課程。一位學生要學習 \(m\) 門課程,求其能獲得的最多學分數。
資料範圍:\(n,m \le 300\)

由於每門課最多隻有一門先修課,與有根樹中一個點最多隻有一個父親節點的特點類似。可以利用這個性質來建樹,從而所有課程形成了一個森林結構。為了方便起見,可以新增一門 \(0\) 學分的課程(編號為 \(0\)),作為所有無先修課課程的先修課,這樣原本的森林就變成了一棵以 \(0\) 號節點為根的樹。

\(dp_{u,i,j}\) 表示以 \(u\) 為根的子樹中,已經遍歷了 \(u\) 號點的前 \(i\) 棵子樹,選了 \(j\) 門課程的最大學分。

轉移的過程結合了樹形 DP 和揹包問題的特點,列舉點 \(u\) 的每個子節點 \(v\),同時列舉以 \(v\) 為根的子樹選了幾門課程,將子樹的結果合併到 \(u\) 上。

將點 \(x\) 的子節點個數記為 \(s_x\),以 \(x\) 為根的子樹大小為 \(sz_x\),則有狀態轉移方程:\(dp_{u,i,j} = \max \{ dp_{u,i-1,j-k} + dp_{v,s_v,k} \}\),注意有一些狀態是無效的,比如 \(k>j\) 或是 \(k>sz_v\) 時。

第二維可以透過滾動陣列最佳化掉,此時需要倒序列舉 \(j\) 的值,同 0-1 揹包問題。

該做法的時間複雜度為 \(O(nm)\),證明見 子樹合併揹包型別的dp的複雜度證明

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 305;
vector<int> tree[N];
int s[N], dp[N][N], n, m;
int dfs(int u) {
    int sz_u = 1;
    dp[u][1] = s[u];
    for (int v : tree[u]) {
        int sz_v = dfs(v);
        for (int i = min(sz_u, m + 1); i > 0; i--)
            for (int j = 1; j <= sz_v && i + j <= m + 1; j++)
                dp[u][i + j] = max(dp[u][i + j], dp[u][i] + dp[v][j]);
        sz_u += sz_v;
    }
    return sz_u;
}
int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        int k;
        scanf("%d%d", &k, &s[i]);
        tree[k].push_back(i);
    }
    dfs(0);
    printf("%d\n", dp[0][m + 1]);
    return 0;
}

習題:P3177 [HAOI2015] 樹上染色

解題思路

顯然設 \(dp_{u,i}\) 代表以 \(u\) 為根節點的子樹,將其中 \(i\) 個點染成黑色的狀態。

但是這個值存什麼呢?如果直接表示子樹內黑點、白點間的收益,這個狀態沒有辦法轉移,因為子樹內最大化收益的染色方案被合併上去後未必是最優的方案,也就是有後效性。

考慮每條邊對最終答案的貢獻,如果一條邊的兩側有一對黑點或白點,則這條邊對這兩個點構成的路徑是有貢獻的。也就是說,一條邊對總答案的貢獻次數等於邊的兩側同色點個數的乘積。而子樹內每條邊對總答案的貢獻這個狀態值在子樹合併過程中是可以向上傳遞的。

因此 \(dp_{u,i}\) 代表以 \(u\) 為根節點的子樹中有 \(i\) 個點被染成黑色後子樹內每一條邊對總答案的貢獻,類似樹上揹包,當要合併某個子節點 \(v\) 對應的子樹時,列舉子樹中的黑點數量 \(j\),則對應的轉移為 \(dp_{u,i-j} + dp_{v,j} + w(u,v) \times c(u,v)\),其中 \(w(u,v)\) 代表邊權,\(c(u,v)\) 代表 u-v 這條邊對總答案的貢獻,而 \(c(u,v) = j \times (k - j) + (sz_v - j) \times (n - sz_v - (k - j))\)

參考程式碼
#include <cstdio>
#include <utility>
#include <vector>
#include <algorithm>
using std::pair;
using std::vector;
using std::min;
using std::max;
typedef long long LL;
typedef pair<int, int> PII;
const int N = 2005;
int n, k;
vector<PII> tree[N];
LL dp[N][N]; // dp[u][i] u的子樹內染i個黑點時邊對答案的總貢獻
int dfs(int u, int fa) {
    int sz_u = 1;
    for (PII p : tree[u]) {
        int v = p.first, w = p.second;
        if (v == fa) continue;
        int sz_v = dfs(v, u);
        for (int i = min(k, sz_u); i >= 0; i--) {
            for (int j = 1; j <= sz_v && i + j <= k; j++) {
                LL black = 1ll * w * j * (k - j);
                LL white = 1ll * w * (sz_v - j) * (n - sz_v - (k - j));
                dp[u][i + j] = max(dp[u][i + j], dp[u][i] + dp[v][j] + black + white);
            }
            // 若將j=0放在迴圈中,則會立馬更新dp[u][i]
            // 導致接下來計算dp[u][i+j]時用到的dp[u][i]已經不是之前的值了
            LL white = 1ll * w * sz_v * (n - sz_v - k);
            dp[u][i] += dp[v][0] + white;
        }
        sz_u += sz_v;
    }
    return sz_u;
}
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++) {
        int u, v, w; scanf("%d%d%d", &u, &v, &w);
        tree[u].push_back({v, w}); tree[v].push_back({u, w});
    }
    dfs(1, 0);
    printf("%lld\n", dp[1][k]);
    return 0;
}

相關文章