人造情感(emotion)

CJzdc發表於2024-08-17

link

考慮 \(W(U)\) 怎麼求。

定義 \(f_x\) 表示只考慮所有在 \(x\) 子樹內的路徑時最大收益,\(sum_x\) 為只考慮 \(x\) 子樹中路徑,且欽定 \(x\) 不選的最大收益。

\(g\) 的轉移顯然:\(g_x=\sum f_{to}\)

\(f\) 轉移考慮列舉 \(\text{lca}=x\) 的所有路徑 \((u,v,w)\),有:\(f_x\longleftarrow\sum\limits_{i}(sum_i-f_i)+w\),其中 \(i\)\((u,v,w)\) 上的所有點。樹狀陣列最佳化可做到 \(O(n\log n)\)

換根,\(h_x\) 表示強制 \(x\to fa_x\) 邊不能經過,此時最大收益。

考慮怎麼從 \(h_x\) 推到 \(h_{to}\)

一種是直接強制 \(x\) 不能選,即 \(h_{to}\longleftarrow h_x-f_x+sum_x\)

一種考慮列舉一條經過 \(x\) 路徑 \((u,v,w)\),強制選該路徑。此時會對該鏈的鄰域以 \(h_x+\sum\limits_{i}(sum_i-f_i)+w\) 的收益更新。

注意到這個權值與更新的目標無關,於是可以將路徑按照收益從大往小排序。

\(x\) 子節點 \(to\) 更新可以直接暴力找到第一條 \(u,v\) 不在 \(to\) 子樹中的路徑,因為每條路徑只會被跳過 \(2\) 次,該部分為線性。

對於不為 \(x\) 子節點的 \(to\),考慮在 \(u,v\) 處分別記錄 \((u,v,w)\) 和其收益。

每次只需查詢 \(fa_{to}\) 子樹除掉 \(to\) 子樹外的最大收益,可以線段樹維護。

現在有了 \(f,sum,h\),考慮求答案。

對於一對 \(u,v\),假設其 \(\text{lca}=x\),強制選擇後總收益為 \(h_x+\sum\limits_{i}(sum_x-f_x)\)。最小的 \(w\) 即為 \(W(U)\) 減去該值,可以拆貢獻計算。

#include <bits/stdc++.h>
#define ALL(x) begin(x), end(x)
#define All(x, l, r) &x[l], &x[r] + 1
using namespace std;
void file() {
  freopen("1.in", "r", stdin);
  freopen("1.out", "w", stdout);
}
using ll = long long;
using i128 = __int128_t;
template <typename T> using vec = vector<T>;

const int mod = 998244353;

const int nLim = 3e5 + 5, kLim = 1.2e6 + 5;
int n, m, o;
array<int, nLim> siz, fa, dfn, tl, tr, dep;
array<ll, nLim> f, sum, h, pre;
array<array<int, nLim>, 20> mn;
array<vec<int>, nLim> g, id;

struct path {
  int u, v; ll w;
  path(int _u, int _v, ll _w) {
    u = _u; v = _v; w = _w;
  }
};
array<vec<path>, nLim> paths;

void dfs(int x, int Fa) {
  fa[x] = Fa; siz[x] = 1;
  dep[x] = dep[Fa] + 1;
  id[dep[x]].push_back(x);
  mn[0][tl[x] = dfn[x] = ++o] = Fa;
  for(int to : g[x])
    if(to ^ Fa) {
      dfs(to, x);
      siz[x] += siz[to];
    }
  tr[x] = o;
}

int mindfn(int x, int y) { return dfn[x] < dfn[y] ? x : y; }
void init() {
  for(int i = 1; i < 20; i++)
    for(int l = 1, r = (1 << i); r <= n; l++, r++)
      mn[i][l] = mindfn(mn[i - 1][l], mn[i - 1][l + (1 << i - 1)]);
}
int lca(int x, int y) {
  if(x == y) return x;
  if((x = dfn[x]) > (y = dfn[y])) swap(x, y);
  int p = __lg(y - x++);
  return mindfn(mn[p][x], mn[p][y - (1 << p) + 1]);
}

struct BIT {
  array<ll, nLim> tr;
  void update(int x, ll v) {
    for(; x <= n; x += (x & -x)) tr[x] += v;
  }
  ll query(int x) {
    ll res = 0;
    for(; x; x -= (x & -x)) res += tr[x];
    return res;
  }
  void update(int l, int r, ll v) {
    update(l, v); update(r + 1, -v);
  }
}bit;

void dfs2(int x, int Fa) {
  for(int to : g[x])
    if(to ^ Fa) {
      dfs2(to, x);
      sum[x] += f[to];
    }
  f[x] = sum[x];
  for(path k : paths[x])
    f[x] = max(f[x], k.w + bit.query(dfn[k.u]) + bit.query(dfn[k.v]) + sum[x]);
  bit.update(tl[x], tr[x], sum[x] - f[x]);
}

void dfs3(int x, int Fa) {
  pre[x] = pre[Fa] + sum[x] - f[x];
  for(int to : g[x])
    if(to ^ Fa) dfs3(to, x);
}

bool isanc(int x, int y) { return (tl[x] <= dfn[y]) && (tr[x] >= dfn[y]); }

#define ls (o << 1)
#define rs (o << 1 | 1)

struct SGT {
  array<ll, kLim> mx;
  void pu(int o) { mx[o] = max(mx[ls], mx[rs]); }
  void update(int o, int l, int r, int x, ll v) {
    if(l == r) return void(mx[o] = max(mx[o], v));
    int mi = (l + r) >> 1;
    (mi < x) ? update(rs, mi + 1, r, x, v) : update(ls, l, mi, x, v);
    pu(o);
  }
  ll query(int o, int l, int r, int x, int y) {
    if((l > y) || (r < x)) return 0;
    if((l >= x) && (r <= y)) return mx[o];
    int mi = (l + r) >> 1;
    return max(query(ls, l, mi, x, y), query(rs, mi + 1, r, x, y));
  }
}sgt;

int32_t main() {
  // file();
  ios::sync_with_stdio(0); cin.tie(0);
  cin >> n >> m;
  for(int i = 1, u, v; i < n; i++) {
    cin >> u >> v;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  dfs(1, 0); init();
  for(int i = 1, u, v, w; i <= m; i++) {
    cin >> u >> v >> w;
    paths[lca(u, v)].emplace_back(u, v, w);
  }
  dfs2(1, 0); dfs3(1, 0); h[1] = f[1];
  for(int d = 2; d <= n; d++) {
    for(int x : id[d - 1]) {
      for(auto& k : paths[x])
        k.w += pre[k.u] + pre[k.v] - pre[x] - pre[fa[x]];
      sort(ALL(paths[x]), [&](path x, path y) { return x.w > y.w; });
      for(int to : g[x]) {
        if(fa[to] != x) continue;
        h[to] = h[x] - f[x] + sum[x];
        h[to] = max(h[to], sgt.query(1, 1, n, tl[x], tl[to] - 1));
        h[to] = max(h[to], sgt.query(1, 1, n, tr[to] + 1, tr[x]));
        for(path k : paths[x])
          if(!isanc(to, k.u) && !isanc(to, k.v)) {
            h[to] = max(h[to], h[x] + k.w);
            break;
          }
      }
    }
    for(int x : id[d - 1])
      for(path k : paths[x]) {
        sgt.update(1, 1, n, dfn[k.u], h[x] + k.w);
        sgt.update(1, 1, n, dfn[k.v], h[x] + k.w);
      }
  }
  ll all = f[1], res = (i128)all * n * n % mod;
  for(int i = 1; i <= n; i++)
    res = (res - (i128)pre[i] * 2 * n % mod + mod) % mod;
  for(int i = 1; i <= n; i++) {
    ll s = 0;
    for(int to : g[i])
      if(fa[to] == i) s += (ll)siz[to] * siz[to];
    res = (res - (i128)(h[i] - pre[i] - pre[fa[i]]) * ((ll)siz[i] * siz[i] - s) % mod + mod) % mod;
  }
  cout << res << "\n";
  return 0;
}

相關文章