一道非常有啟發性的題目。
思路
考慮對於一個給出點值的多項式函式如何處理。
我們發現,對於一個 \(m\) 次多項式 \(f(x)\),由於 \(\binom{x}{i}\) 為 \(i\) 次多項式,所以說我們必定可以把一個多項式函式寫成如下模樣:
\[F(k)=\sum_{i=0}^m\binom{k}{i}f_i
\]
可以看出,\(f_i\) 實際上是非常好得到的。
我們可以進行二項式反演。
\[\begin{align}
f_k&=\sum_{i=0}^m \binom{k}{i}(-1)^{k-i}F(i)\nonumber \\
&=k!\sum_{i=0}^m \frac{F(i)}{i!}\frac{(-1)^{k-i}}{(k-i)!}\nonumber
\end{align}
\]
卷積處理即可。
這樣的話我們就可以使用簡單的組合數快速求出多項式的點值。
感覺這個操作還是很巧妙的,可能還比較通用。
對於這道題,剩下的部分就很簡單了,我們可以:
\[\begin{align}
&=\sum_{k=0}^n\sum_{i=0}^m\binom{k}{i}f_i\binom{n}{k}x^k(1-x)^{n-k}\nonumber\\
&=\sum_{i=0}^mf_i\sum_{k=0}^n\binom{k}{i}\binom{n}{k}x^k(1-x)^{n-k}\nonumber\\
&=\sum_{i=0}^mf_i\sum_{k=0}^n\frac{n!k!}{k!(n-k)!i!(k-i)!}x^k(1-x)^{n-k}\nonumber\\
&=\sum_{i=0}^mf_i\frac{n!}{i!}\sum_{k=0}^n\frac{1}{(n-k)!(k-i)!}x^k(1-x)^{n-k}\nonumber\\
&=\sum_{i=0}^mf_i\frac{n!}{i!(n-i)!}\sum_{k=0}^n\frac{(n-i)!}{(n-k)!(k-i)!}x^k(1-x)^{n-k}\nonumber\\
&=\sum_{i=0}^mf_i\binom{n}{i}\sum_{k=0}^n\binom{n-i}{k-i}x^k(1-x)^{n-k}\nonumber\\
&=\sum_{i=0}^mf_i\binom{n}{i}\sum_{k=0}^{n-i}\binom{n-i}{k}x^{k+i}(1-x)^{n-k-i}\nonumber\\
&=\sum_{i=0}^mf_ix^i\binom{n}{i}\sum_{k=0}^{n-i}\binom{n-i}{k}x^{k}(1-x)^{n-i-k}\nonumber\\
&=\sum_{i=0}^mf_ix^i\binom{n}{i} (x+1-x)^{n-i}\nonumber\\
&=\sum_{i=0}^mf_ix^i\binom{n}{i}\nonumber\\
\end{align}
\]
複雜度瓶頸在前面的處理 \(f_i\)。
時間複雜度:\(O(m\log m)\)。
Code
#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353;
const int G = 3;
const int I = 332748118;
int n, m, x, k;
int a[20010];
int fc[20010];
int iv[20010];
int f[1 << 16];
int g[1 << 16];
int b[1 << 16];
int w[1 << 16];
inline int power(int x, int y) {
int res = 1;
while (y) {
if (y & 1) res = 1ll * res * x % mod;
x = 1ll * x * x % mod, y >>= 1;
}
return res;
}
inline void init(int n) {
int x = __lg(n) + 1;
if (k == (1 << x)) return;
k = (1 << x);
for (int i = 0; i < k; i++)
b[i] = (b[i >> 1] >> 1) | ((i & 1) ? (k >> 1) : 0);
}
inline void ntt(int *f, int n, int flag) {
init(n), w[0] = 1;
for (int i = 0; i < k; i++) if (i < b[i]) swap(f[i], f[b[i]]);
for (int i = 1; i < k; i <<= 1) {
int b = i << 1;
int w0 = power((flag ? G : I), (mod - 1) / b);
for (int j = 1; j < i; j++) w[j] = 1ll * w[j - 1] * w0 % mod;
for (int j = 0; j < k; j += b) {
for (int l = 0; l < i; l++) {
int x = f[j + l], y = 1ll * f[j + l + i] * w[l] % mod;
f[j + l] = (x + y >= mod ? x + y - mod : x + y);
f[j + l + i] = (x - y < 0 ? x - y + mod : x - y);
}
}
}
if (flag == 0) {
int iv = power(k, mod - 2);
for (int i = 0; i < k; i++) f[i] = 1ll * f[i] * iv % mod;
}
}
int main() {
cin >> n >> m >> x;
for (int i = 0; i <= m; i++) cin >> a[i];
fc[0] = 1;
for (int i = 1; i <= m; i++) fc[i] = 1ll * fc[i - 1] * i % mod;
iv[m] = power(fc[m], mod - 2);
for (int i = m; i >= 1; i--) iv[i - 1] = 1ll * iv[i] * i % mod;
for (int i = 0; i <= m; i++) {
f[i] = 1ll * a[i] * iv[i] % mod;
g[i] = (i & 1 ? mod - iv[i] : iv[i]);
}
ntt(f, m + m, 1);
ntt(g, m + m, 1);
for (int i = 0; i < k; i++)
f[i] = 1ll * f[i] * g[i] % mod;
ntt(f, m + m, 0);
int sm = 1;
int ns = 0;
for (int i = 0; i <= m; i++) {
ns = (ns + 1ll * sm * f[i]) % mod;
sm = (1ll * sm * x) % mod;
sm = (1ll * sm * (n - i)) % mod;
}
cout << ns << "\n";
}