樹上字首和與差分

RonChen發表於2024-06-10

樹上字首和

\(sum_i\) 表示根節點到節點 \(i\) 的權值總和。
則有:

  • 對於點權,\(x,y\) 路徑上的和為 \(sum_x + sum_y - sum_{lca} - sum_{fa_{lca}}\)
  • 對於邊權,\(x,y\) 路徑上的和為 \(sum_x + sum_y - 2 \times sum_{lca}\)

習題:P4427 [BJOI2018] 求和

解題思路

預處理出 \(sum_{i,k}\) 表示根節點到節點 \(i\) 的深度的 \(k\) 次方和,這個過程的時間複雜度為 \(O(nk)\),後面即為樹上點權字首和問題。

參考程式碼
#include <cstdio>
#include <vector>
#include <algorithm>
using std::swap;
using std::vector;
const int N = 3e5 + 5;
const int K = 55;
const int LOG = 19;
const int MOD = 998244353;
vector<int> tree[N];
int fa[N][LOG], depth[N], sum[N][K];
void dfs(int u, int pre) {
    depth[u] = depth[pre] + 1;
    int d = 1;
    for (int i = 0; i < K; i++) {
        sum[u][i] = (sum[pre][i] + d) % MOD;
        d = 1ll * d * depth[u] % MOD;
    } 
    fa[u][0] = pre;
    for (int v : tree[u]) {
        if (v == pre) continue;
        dfs(v, u);
    }
}
int lca(int x, int y) {
    if (depth[x] < depth[y]) swap(x, y);
    int delta = depth[x] - depth[y];
    for (int i = LOG - 1; i >= 0; i--)
        if (delta & (1 << i)) x = fa[x][i];
    if (x == y) return x;
    for (int i = LOG - 1; i >= 0; i--) {
        if (fa[x][i] != fa[y][i]) {
            x = fa[x][i]; y = fa[y][i];
        }
    }
    return fa[x][0];
}
int main()
{
    int n; scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int x, y; scanf("%d%d", &x, &y);
        tree[x].push_back(y);
        tree[y].push_back(x);
    }
    depth[0] = -1;
    dfs(1, 0);
    for (int i = 1; i < LOG; i++) {
        for (int j = 1; j <= n; j++) fa[j][i] = fa[fa[j][i - 1]][i - 1];
    }
    int m; scanf("%d", &m);
    while (m--) {
        int i, j, k; scanf("%d%d%d", &i, &j, &k);
        int lca_ij = lca(i, j), f = fa[lca_ij][0];
        int ans1 = (sum[i][k] + MOD - sum[f][k]) % MOD;
        int ans2 = (sum[j][k] + MOD - sum[lca_ij][k]) % MOD;
        printf("%d\n", (ans1 + ans2) % MOD);
    }
    return 0;
}

樹上差分

樹上差分可以理解為對樹上的某一段路徑進行差分操作,這裡的路徑可以類比一維陣列的區間進行理解。例如在對樹上的一些路徑進行頻繁操作,並且詢問某條邊或者某個點在經過操作後的值的時候,就可以運用樹上差分思想。

點差分

例題:P3128 [USACO15DEC] Max Flow P

問題描述:有 \(n\) 個節點,用 \(n-1\) 條邊連線,所有節點都連通。給出 \(k\) 條路徑,第 \(i\) 條路徑為節點 \(s_i\)\(t_i\)。每給出一條路徑,路徑上所有節點的權值加 \(1\)。輸出最大權值點的權值。
資料範圍:\(2 \le n \le 50000, 1 \le k \le 100000\)

樹上兩點 \(u,v\) 的路徑指的是最短路徑。可以把 \(u \rightarrow v\) 的路徑分為兩個部分:\(u \rightarrow LCA(u,v)\)\(LCA(u,v) \rightarrow v\)

先考慮簡單的思路。首先對每條路徑求 LCA,分別以 \(u\)\(v\) 為起點到 LCA,把路徑上每個節點的權值加 \(1\);然後對所有路徑進行類似操作。把路徑上每個節點加 \(1\) 的操作的複雜度為 \(O(n)\),共 \(k\) 次操作,會超時。

本題的關鍵是如何記錄路徑上每個節點的修改。顯然,如果真的對每個節點都記錄修改,肯定會超時。我們可以利用差分,因為差分的用途是“把區間問題轉換為斷電問題”,適用這種情況。

給定陣列 \(a\),定義差分陣列 \(D[k]=a[k]-a[k-1]\),即陣列相鄰元素的差。

從差分陣列的定義可以推出:\(a[k]=D[1]+D[2]+ \cdots + D[k] = \sum\limits_{i=1}^{k} D[i]\)

這個式子描述了 \(a\)\(D\) 的關係,即“差分是字首和的逆運算” ,它把求 \(a[k]\) 轉換為求 \(D\) 的字首和。

對於區間 \([L,R]\) 的修改問題,比如把區間內每個元素都加上 \(d\),則可以對區間的兩個端點 \(L\)\(R+1\) 做以下操作:

  1. \(D[L]\) 加上 \(d\)
  2. \(D[R+1]\) 減去 \(d\)

image

\(D\) 求字首和,則可得到 \(a\) 陣列,以上的更新相當於:

  1. \(1 \le x < L\)\(a[x]\) 不變;
  2. \(L \le x \le R\)\(a[x]\) 增加了 \(d\)
  3. \(R < x \le N\)\(a[x]\) 不變,因為被 \(D[R+1]\) 中減去的 \(d\) 抵消了。

利用差分能夠把區間修改問題轉換為只用端點做記錄。如果不用差分陣列,區間內每個元素都需要修改,時間複雜度為 \(O(n)\);轉換為只修改兩個端點後,時間複雜度降到 \(O(1)\),這就是差分的重要作用。

把差分思想用到樹上,只需要把樹上路徑轉換為區間即可。把一條路徑 \(u \rightarrow v\) 分為兩部分:\(u \rightarrow LCA(u,v)\)\(LCA(u,v) \rightarrow v\),這樣每條路徑都可以當成一個區間處理。

\(LCA(u,v)=R\),並記 \(R\) 的父節點為 \(F=fa[R]\),要把路徑上每個節點權值加 \(1\),有:

  1. 路徑 \(u \rightarrow R\) 這個區間上,\(D[u]++\)\(D[F]--\)
  2. 路徑 \(v \rightarrow R\) 這個區間上,\(D[v]++\)\(D[F]--\)

經過以上操作,能透過 \(D\) 計算出 \(u \rightarrow v\) 上每個節點的權值。不過,由於兩條路徑在 \(R\)\(F\) 這裡重合了,這兩個步驟把 \(D[R]\) 加了兩次,把 \(D[F]\) 減了兩次,需要調整為 \(D[R]--\)\(D[F]--\)

image

在本題中,對每條路徑都用倍增法求一次 LCA,並做一次差分操作。當對於所有路徑都操作完成後,再做一次 DFS,求出每個節點的權值,所有權值中的最大值即為答案。

\(k\) 次 LCA 的時間複雜度為 \(O(n \log n + k \log n)\);最後做一次 DFS,時間複雜度為 \(O(n)\);總的時間複雜度為 \(O((n+k) \log n)\)

參考程式碼
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 50005;
const int LOG = 16;
vector<int> tree[N];
int d[N], fa[N][LOG], a[N], ans;
void dfs(int cur, int pre) {
    d[cur] = d[pre] + 1;
    fa[cur][0] = pre;
    for (int i = 1; i < LOG; i++) fa[cur][i] = fa[fa[cur][i - 1]][i - 1];
    for (int nxt : tree[cur])
        if (nxt != pre) dfs(nxt, cur);
}
int lca(int x, int y) {
    if (d[x] < d[y]) swap(x, y);
    int len = d[x] - d[y];
    for (int i = LOG - 1; i >= 0; i--) 
        if (1 << i <= len) {
            x = fa[x][i]; len -= 1 << i;
        }
    if (x == y) return x;
    for (int i = LOG - 1; i >= 0; i--) 
        if (fa[x][i] != fa[y][i]) {
            x = fa[x][i]; y = fa[y][i];
        }
    return fa[x][0];
}
void calc(int cur, int pre) {
    for (int nxt : tree[cur])
        if (nxt != pre) {
            calc(nxt, cur); 
            a[cur] += a[nxt];
        } 
    ans = max(ans, a[cur]);
}
int main()
{
    int n, k;
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        tree[x].push_back(y); tree[y].push_back(x);
    }
    dfs(1, 0); // 計算每個節點的深度並預處理fa陣列
    while (k--) {
        int s, t;
        scanf("%d%d", &s, &t);
        int r = lca(s, t);
        a[s]++; a[t]++; a[r]--; a[fa[r][0]]--; // 樹上差分
    }  
    calc(1, 0); // 用差分陣列求每個節點的權值
    printf("%d\n", ans);
    return 0;
}

邊差分

例題:P6869 [COCI2019-2020#5] Putovanje

顯然針對每一條邊只會考慮購買單程票和多程票的一種,這取決於該條邊被經過的次數 \(k\),這樣一來這條邊上的最少花費是 \(\min (k c_1, c_2)\)

這裡需要根據若干條路徑計算出每條邊經過的次數,可以藉助差分思想,注意它和點差分不同。對於邊相關的問題,一般我們會將每個點與它父親節點相連的邊與該點繫結,從而將邊上資訊的維護轉化為對點的資訊的維護

image

參考程式碼
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 200005;
const int LOG = 19;
vector<int> tree[N];
int d[N], fa[N][LOG], cnt[N], a[N], b[N], c1[N], c2[N];
void dfs(int cur, int pre) {
    d[cur] = d[pre] + 1;
    fa[cur][0] = pre;
    for (int i = 1; i < LOG; i++) fa[cur][i] = fa[fa[cur][i - 1]][i - 1];
    for (int nxt : tree[cur]) 
        if (nxt != pre) dfs(nxt, cur);
}
int lca(int x, int y) {
    if (d[x] < d[y]) swap(x, y);
    int len = d[x] - d[y];
    for (int i = LOG - 1; i >= 0; i--) 
        if ((1 << i) <= len) {
            x = fa[x][i]; len -= 1 << i;
        }
    if (x == y) return x;
    for (int i = LOG - 1; i >= 0; i--)
        if (fa[x][i] != fa[y][i]) {
            x = fa[x][i]; y = fa[y][i];
        }
    return fa[x][0];
}
void calc(int cur, int pre) {
    for (int nxt : tree[cur]) 
        if (nxt != pre) {
            calc(nxt, cur);
            cnt[cur] += cnt[nxt];
        }
}
int main()
{
    int n;
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        scanf("%d%d%d%d", &a[i], &b[i], &c1[i], &c2[i]);
        tree[a[i]].push_back(b[i]);
        tree[b[i]].push_back(a[i]);
    }
    dfs(1, 0);
    for (int i = 1; i < n; i++) {
        int r = lca(i, i + 1);
        cnt[i]++; cnt[i + 1]++; cnt[r] -= 2;
    }
    calc(1, 0);
    LL ans = 0;
    for (int i = 1; i < n; i++) {
        if (d[a[i]] > d[b[i]]) ans += min(1ll * c1[i] * cnt[a[i]], 1ll * c2[i]);
        else ans += min(1ll * c1[i] * cnt[b[i]], 1ll * c2[i]);
    }
    printf("%lld\n", ans);
    return 0;
}

相關文章