NOIP2024 加賽 8

ccxswl發表於2024-11-27

騙你的,沒寫。

不過這場分比較高,前三道切的都挺順,T4 也拿了暴力分。

T3 和題解的處理辦法不太一樣,具體就是沒有統計每條邊的貢獻,樹上 DP 求的是子樹內的答案,處理修改的時候也不一樣。

就掛個程式碼吧。

#include <bits/stdc++.h>

using namespace std;

#define int long long
using ubt = long long;
using uubt = unsigned long long;
#define vec vector
#define eb emplace_back
#define bg begin
#define emp emplace
#define mkp make_pair
#define fi first
#define se second
using pii = pair<int, int>;
const int inf = 1e9;

const int maxN = 5e5 + 7;
const int mod = 998244353;
const int I = 499122177;

int n, m;

bool is[maxN];
ubt num[maxN], tot[maxN], f[maxN];

vec<pii> g[maxN];

int snum[maxN], snum2[maxN];

void dfs(int x, int fa) {
  num[x] = is[x];
  f[x] = tot[x] = 0;
  for (auto &[to, w] : g[x]) {
    if (to == fa) continue;
    dfs(to, x);
    f[x] += f[to] + num[x] * tot[to] % mod + w * num[x] * num[to] % mod + num[to] * tot[x] % mod;
    f[x] %= mod;
    num[x] += num[to];
    tot[x] += tot[to] + w * num[to] % mod;
    tot[x] %= mod;

    snum[x] += num[to];
    snum2[x] += num[to] * num[to] % mod;
    snum2[x] %= mod;
  }
}

int ans;

signed main() {
  freopen("sakuya.in", "r", stdin);
  freopen("sakuya.out", "w", stdout);

  cin.tie(nullptr)->sync_with_stdio(false);

  cin >> n >> m;
  for (int i = 1; i < n; i++) {
    int u, v, w;
    cin >> u >> v >> w;
    g[u].eb(v, w);
    g[v].eb(u, w);
  }
  for (int i = 1; i <= m; i++) {
    int a;
    cin >> a;
    is[a] = true;
  }
  dfs(1, 0);
  ans = f[1];

  auto ksm = [](int a, int b) {
    int res = 1;
    while (b) {
      if (b & 1) res = res * a % mod;
      a = a * a % mod;
      b >>= 1;
    }
    return res;
  };
  auto M = ksm(m, mod - 2) * 2 % mod;

  int Q;
  cin >> Q;
  while (Q--) {
    //cerr << '\n';
    int x, k;
    cin >> x >> k;
    
    ans += (snum[x] * snum[x] % mod - snum2[x] + mod) * k % mod;
    ans %= mod;
    if (is[x]) {
      ans += snum[x] * k % mod;
      ans %= mod;
      ans += (m - num[x]) * k % mod;
      ans %= mod;
    }
    ans += snum[x] * (m - num[x]) % mod * k % mod * 2 % mod;
    ans %= mod;

    if (ans < 0) ans += mod;
    cout << ans * M % mod << '\n';
  }
}

相關文章