P6667 [清華集訓2016] 如何優雅地求和 題解

JiaY19發表於2024-11-06

一道非常有啟發性的題目。

思路

考慮對於一個給出點值的多項式函式如何處理。

我們發現,對於一個 \(m\) 次多項式 \(f(x)\),由於 \(\binom{x}{i}\)\(i\) 次多項式,所以說我們必定可以把一個多項式函式寫成如下模樣:

\[F(k)=\sum_{i=0}^m\binom{k}{i}f_i \]

可以看出,\(f_i\) 實際上是非常好得到的。

我們可以進行二項式反演。

\[\begin{align} f_k&=\sum_{i=0}^m \binom{k}{i}(-1)^{k-i}F(i)\nonumber \\ &=k!\sum_{i=0}^m \frac{F(i)}{i!}\frac{(-1)^{k-i}}{(k-i)!}\nonumber \end{align} \]

卷積處理即可。

這樣的話我們就可以使用簡單的組合數快速求出多項式的點值。

感覺這個操作還是很巧妙的,可能還比較通用。

對於這道題,剩下的部分就很簡單了,我們可以:

\[\begin{align} &=\sum_{k=0}^n\sum_{i=0}^m\binom{k}{i}f_i\binom{n}{k}x^k(1-x)^{n-k}\nonumber\\ &=\sum_{i=0}^mf_i\sum_{k=0}^n\binom{k}{i}\binom{n}{k}x^k(1-x)^{n-k}\nonumber\\ &=\sum_{i=0}^mf_i\sum_{k=0}^n\frac{n!k!}{k!(n-k)!i!(k-i)!}x^k(1-x)^{n-k}\nonumber\\ &=\sum_{i=0}^mf_i\frac{n!}{i!}\sum_{k=0}^n\frac{1}{(n-k)!(k-i)!}x^k(1-x)^{n-k}\nonumber\\ &=\sum_{i=0}^mf_i\frac{n!}{i!(n-i)!}\sum_{k=0}^n\frac{(n-i)!}{(n-k)!(k-i)!}x^k(1-x)^{n-k}\nonumber\\ &=\sum_{i=0}^mf_i\binom{n}{i}\sum_{k=0}^n\binom{n-i}{k-i}x^k(1-x)^{n-k}\nonumber\\ &=\sum_{i=0}^mf_i\binom{n}{i}\sum_{k=0}^{n-i}\binom{n-i}{k}x^{k+i}(1-x)^{n-k-i}\nonumber\\ &=\sum_{i=0}^mf_ix^i\binom{n}{i}\sum_{k=0}^{n-i}\binom{n-i}{k}x^{k}(1-x)^{n-i-k}\nonumber\\ &=\sum_{i=0}^mf_ix^i\binom{n}{i} (x+1-x)^{n-i}\nonumber\\ &=\sum_{i=0}^mf_ix^i\binom{n}{i}\nonumber\\ \end{align} \]

複雜度瓶頸在前面的處理 \(f_i\)

時間複雜度:\(O(m\log m)\)

Code

#include <bits/stdc++.h>
using namespace std;

const int mod = 998244353;
const int G = 3;
const int I = 332748118;

int n, m, x, k;
int a[20010];
int fc[20010];
int iv[20010];
int f[1 << 16];
int g[1 << 16];
int b[1 << 16];
int w[1 << 16];

inline int power(int x, int y) {
  int res = 1;
  while (y) {
    if (y & 1) res = 1ll * res * x % mod;
    x = 1ll * x * x % mod, y >>= 1;
  }
  return res;
}
inline void init(int n) {
  int x = __lg(n) + 1;
  if (k == (1 << x)) return;
  k = (1 << x);
  for (int i = 0; i < k; i++)
    b[i] = (b[i >> 1] >> 1) | ((i & 1) ? (k >> 1) : 0);
}
inline void ntt(int *f, int n, int flag) {
  init(n), w[0] = 1;
  for (int i = 0; i < k; i++) if (i < b[i]) swap(f[i], f[b[i]]);
  for (int i = 1; i < k; i <<= 1) {
    int b = i << 1;
    int w0 = power((flag ? G : I), (mod - 1) / b);
    for (int j = 1; j < i; j++) w[j] = 1ll * w[j - 1] * w0 % mod;
    for (int j = 0; j < k; j += b) {
      for (int l = 0; l < i; l++) {
        int x = f[j + l], y = 1ll * f[j + l + i] * w[l] % mod;
        f[j + l] = (x + y >= mod ? x + y - mod : x + y);
        f[j + l + i] = (x - y < 0 ? x - y + mod : x - y);
      }
    }
  }
  if (flag == 0) {
    int iv = power(k, mod - 2);
    for (int i = 0; i < k; i++) f[i] = 1ll * f[i] * iv % mod;
  }
}

int main() {
  cin >> n >> m >> x;
  for (int i = 0; i <= m; i++) cin >> a[i];
  fc[0] = 1;
  for (int i = 1; i <= m; i++) fc[i] = 1ll * fc[i - 1] * i % mod;
  iv[m] = power(fc[m], mod - 2);
  for (int i = m; i >= 1; i--) iv[i - 1] = 1ll * iv[i] * i % mod;
  for (int i = 0; i <= m; i++) {
    f[i] = 1ll * a[i] * iv[i] % mod;
    g[i] = (i & 1 ? mod - iv[i] : iv[i]);
  }
  ntt(f, m + m, 1);
  ntt(g, m + m, 1);
  for (int i = 0; i < k; i++)
    f[i] = 1ll * f[i] * g[i] % mod;
  ntt(f, m + m, 0);
  int sm = 1;
  int ns = 0;
  for (int i = 0; i <= m; i++) {
    ns = (ns + 1ll * sm * f[i]) % mod;
    sm = (1ll * sm * x) % mod;
    sm = (1ll * sm * (n - i)) % mod;
  }
  cout << ns << "\n";
}

相關文章