前言
題目連結:Codeforces;洛谷。
題意簡述
你有一個長度為 \(n\) 的序列 \(p\) 滿足 \(p_i=i\),你可以進行 \(x\) 次操作,每次操作找到兩個不同的 \(i,j\) 並且交換 \(p_i,p_j\),問最終有幾個可能的序列。分別求出 \(x = 1, \ldots, k\) 時的答案。
\(1 \le n \le 10^9\),\(1\le k \le 200\)。
題目分析
先考慮暴力 DP。顯然原問題等價於求有多少排列經過 \(x\) 次交換後得到 \(p_i = i\)。設 \(f_i(x)\) 表示有多少長度為 \(x\) 的排列,至少經過 \(i\) 次操作可以得到原序列。邊界 \(f_0(x) = 1\)。考慮轉移。對於 \(f_i(x)\) 考慮 \(p_x\) 的值。若 \(p_x = x\),有 \(f_i(x) = f_i(x - 1)\);否則需要進行一次交換 \(f_i(x) = f_{i - 1}(x - 1)\)。綜合得到 \(f_i(x) = f_i(x - 1) + f_{i - 1}(x - 1)\)。由於我們總是能浪費偶數次操作,所以對於一個 \(x\),答案為 \(\sum f_{x - 2t}(n)\)。
這麼做時間複雜度是 \(\Theta(nk)\) 的。
顯然瓶頸在於 \(n\),考慮最佳化掉它。發現進行 \(x\) 次操作,最多隻有 \(2x\) 個關鍵點發生變化,不妨從這裡入手。
不妨列舉有 \(i\) 個位置發生了變化,類似浪費 \(2t\) 次操作,對答案的貢獻為 \(\binom{n}{i} \sum g_{x - 2t}(i)\),其中 \(g_i(x)\) 表示有多少長度為 \(x\) 的序列,每一個位置都發生了變化,至少經過 \(i\) 次操作變回原序列。
顯然 \(g\) 不等價於 \(f\),因為 \(f\) 計算的時候可能存在 \(p_i = i\)。不妨找找二者關係。我們原本可以用 \(f\) 表示出答案,現在答案和 \(g\) 有關,說明 \(f\) 也能夠由 \(g\) 表示出。發現如果我們類比統計答案,欽定只有某些位置發生變化,有:
這很二項式反演,不會的可以看看我的《學習筆記》。
根據經典定理,我們得到:
於是問題迎刃而解。時間複雜度 \(\Theta(k^3)\)。
程式碼
#include <cstdio>
#include <iostream>
using namespace std;
const int mod = 1e9 + 7;
inline int add(int a, int b) {
return a + b >= mod ? a + b - mod : a + b;
}
inline int sub(int a, int b) {
return a - b < 0 ? a - b + mod : a - b;
}
inline int mul(int a, int b) {
return 1ll * a * b % mod;
}
int n, k;
int f[420][420], g[420][420];
int frac[420], Inv[420], ifrac[420], tfrac[420];
void init(int n = 400) {
frac[0] = ifrac[0] = tfrac[0] = 1;
for (int i = 1; i <= n; ++i) {
frac[i] = mul(frac[i - 1], i);
Inv[i] = i == 1 ? 1 : sub(0, mul(mod / i, Inv[mod % i]));
ifrac[i] = mul(ifrac[i - 1], Inv[i]);
tfrac[i] = mul(tfrac[i - 1], ::n - i + 1);
}
}
inline int C(int m) { // C(n, m) = n * ... * (n - m + 1) / m!
return mul(tfrac[m], ifrac[m]);
}
inline int C(int n, int m) {
return mul(frac[n], mul(ifrac[n - m], ifrac[m]));
}
signed main() {
scanf("%d%d", &n, &k), init(), f[0][0] = 1;
for (int i = 1; i <= min(n, k << 1); ++i) {
f[i][0] = 1;
for (int j = 1; j <= k; ++j) {
f[i][j] = add(f[i - 1][j], mul(i - 1, f[i - 1][j - 1]));
}
}
for (int i = 0; i <= min(n, k << 1); ++i)
for (int j = 0; j <= k; ++j)
for (int x = 0; x <= i; ++x) {
if ((i - x) & 1)
g[i][j] = sub(g[i][j], mul(C(i, x), f[x][j]));
else
g[i][j] = add(g[i][j], mul(C(i, x), f[x][j]));
}
for (int i = 1; i <= k; ++i) {
int res = 0;
for (int j = min(i << 1, n); j >= 0; --j) {
if (i >= 2) g[j][i] = add(g[j][i], g[j][i - 2]);
res = add(res, mul(g[j][i], C(j)));
}
printf("%d ", res);
}
return 0;
}