link
考慮 \(W(U)\) 怎麼求。
定義 \(f_x\) 表示只考慮所有在 \(x\) 子樹內的路徑時最大收益,\(sum_x\) 為只考慮 \(x\) 子樹中路徑,且欽定 \(x\) 不選的最大收益。
\(g\) 的轉移顯然:\(g_x=\sum f_{to}\)
\(f\) 轉移考慮列舉 \(\text{lca}=x\) 的所有路徑 \((u,v,w)\),有:\(f_x\longleftarrow\sum\limits_{i}(sum_i-f_i)+w\),其中 \(i\) 為 \((u,v,w)\) 上的所有點。樹狀陣列最佳化可做到 \(O(n\log n)\)。
換根,\(h_x\) 表示強制 \(x\to fa_x\) 邊不能經過,此時最大收益。
考慮怎麼從 \(h_x\) 推到 \(h_{to}\)。
一種是直接強制 \(x\) 不能選,即 \(h_{to}\longleftarrow h_x-f_x+sum_x\);
一種考慮列舉一條經過 \(x\) 路徑 \((u,v,w)\),強制選該路徑。此時會對該鏈的鄰域以 \(h_x+\sum\limits_{i}(sum_i-f_i)+w\) 的收益更新。
注意到這個權值與更新的目標無關,於是可以將路徑按照收益從大往小排序。
\(x\) 子節點 \(to\) 更新可以直接暴力找到第一條 \(u,v\) 不在 \(to\) 子樹中的路徑,因為每條路徑只會被跳過 \(2\) 次,該部分為線性。
對於不為 \(x\) 子節點的 \(to\),考慮在 \(u,v\) 處分別記錄 \((u,v,w)\) 和其收益。
每次只需查詢 \(fa_{to}\) 子樹除掉 \(to\) 子樹外的最大收益,可以線段樹維護。
現在有了 \(f,sum,h\),考慮求答案。
對於一對 \(u,v\),假設其 \(\text{lca}=x\),強制選擇後總收益為 \(h_x+\sum\limits_{i}(sum_x-f_x)\)。最小的 \(w\) 即為 \(W(U)\) 減去該值,可以拆貢獻計算。
#include <bits/stdc++.h>
#define ALL(x) begin(x), end(x)
#define All(x, l, r) &x[l], &x[r] + 1
using namespace std;
void file() {
freopen("1.in", "r", stdin);
freopen("1.out", "w", stdout);
}
using ll = long long;
using i128 = __int128_t;
template <typename T> using vec = vector<T>;
const int mod = 998244353;
const int nLim = 3e5 + 5, kLim = 1.2e6 + 5;
int n, m, o;
array<int, nLim> siz, fa, dfn, tl, tr, dep;
array<ll, nLim> f, sum, h, pre;
array<array<int, nLim>, 20> mn;
array<vec<int>, nLim> g, id;
struct path {
int u, v; ll w;
path(int _u, int _v, ll _w) {
u = _u; v = _v; w = _w;
}
};
array<vec<path>, nLim> paths;
void dfs(int x, int Fa) {
fa[x] = Fa; siz[x] = 1;
dep[x] = dep[Fa] + 1;
id[dep[x]].push_back(x);
mn[0][tl[x] = dfn[x] = ++o] = Fa;
for(int to : g[x])
if(to ^ Fa) {
dfs(to, x);
siz[x] += siz[to];
}
tr[x] = o;
}
int mindfn(int x, int y) { return dfn[x] < dfn[y] ? x : y; }
void init() {
for(int i = 1; i < 20; i++)
for(int l = 1, r = (1 << i); r <= n; l++, r++)
mn[i][l] = mindfn(mn[i - 1][l], mn[i - 1][l + (1 << i - 1)]);
}
int lca(int x, int y) {
if(x == y) return x;
if((x = dfn[x]) > (y = dfn[y])) swap(x, y);
int p = __lg(y - x++);
return mindfn(mn[p][x], mn[p][y - (1 << p) + 1]);
}
struct BIT {
array<ll, nLim> tr;
void update(int x, ll v) {
for(; x <= n; x += (x & -x)) tr[x] += v;
}
ll query(int x) {
ll res = 0;
for(; x; x -= (x & -x)) res += tr[x];
return res;
}
void update(int l, int r, ll v) {
update(l, v); update(r + 1, -v);
}
}bit;
void dfs2(int x, int Fa) {
for(int to : g[x])
if(to ^ Fa) {
dfs2(to, x);
sum[x] += f[to];
}
f[x] = sum[x];
for(path k : paths[x])
f[x] = max(f[x], k.w + bit.query(dfn[k.u]) + bit.query(dfn[k.v]) + sum[x]);
bit.update(tl[x], tr[x], sum[x] - f[x]);
}
void dfs3(int x, int Fa) {
pre[x] = pre[Fa] + sum[x] - f[x];
for(int to : g[x])
if(to ^ Fa) dfs3(to, x);
}
bool isanc(int x, int y) { return (tl[x] <= dfn[y]) && (tr[x] >= dfn[y]); }
#define ls (o << 1)
#define rs (o << 1 | 1)
struct SGT {
array<ll, kLim> mx;
void pu(int o) { mx[o] = max(mx[ls], mx[rs]); }
void update(int o, int l, int r, int x, ll v) {
if(l == r) return void(mx[o] = max(mx[o], v));
int mi = (l + r) >> 1;
(mi < x) ? update(rs, mi + 1, r, x, v) : update(ls, l, mi, x, v);
pu(o);
}
ll query(int o, int l, int r, int x, int y) {
if((l > y) || (r < x)) return 0;
if((l >= x) && (r <= y)) return mx[o];
int mi = (l + r) >> 1;
return max(query(ls, l, mi, x, y), query(rs, mi + 1, r, x, y));
}
}sgt;
int32_t main() {
// file();
ios::sync_with_stdio(0); cin.tie(0);
cin >> n >> m;
for(int i = 1, u, v; i < n; i++) {
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0); init();
for(int i = 1, u, v, w; i <= m; i++) {
cin >> u >> v >> w;
paths[lca(u, v)].emplace_back(u, v, w);
}
dfs2(1, 0); dfs3(1, 0); h[1] = f[1];
for(int d = 2; d <= n; d++) {
for(int x : id[d - 1]) {
for(auto& k : paths[x])
k.w += pre[k.u] + pre[k.v] - pre[x] - pre[fa[x]];
sort(ALL(paths[x]), [&](path x, path y) { return x.w > y.w; });
for(int to : g[x]) {
if(fa[to] != x) continue;
h[to] = h[x] - f[x] + sum[x];
h[to] = max(h[to], sgt.query(1, 1, n, tl[x], tl[to] - 1));
h[to] = max(h[to], sgt.query(1, 1, n, tr[to] + 1, tr[x]));
for(path k : paths[x])
if(!isanc(to, k.u) && !isanc(to, k.v)) {
h[to] = max(h[to], h[x] + k.w);
break;
}
}
}
for(int x : id[d - 1])
for(path k : paths[x]) {
sgt.update(1, 1, n, dfn[k.u], h[x] + k.w);
sgt.update(1, 1, n, dfn[k.v], h[x] + k.w);
}
}
ll all = f[1], res = (i128)all * n * n % mod;
for(int i = 1; i <= n; i++)
res = (res - (i128)pre[i] * 2 * n % mod + mod) % mod;
for(int i = 1; i <= n; i++) {
ll s = 0;
for(int to : g[i])
if(fa[to] == i) s += (ll)siz[to] * siz[to];
res = (res - (i128)(h[i] - pre[i] - pre[fa[i]]) * ((ll)siz[i] * siz[i] - s) % mod + mod) % mod;
}
cout << res << "\n";
return 0;
}