bzoj 2655: calc [容斥原理 伯努利數]

Candy?發表於2017-05-03

2655: calc

題意:長n的序列,每個數\(a_i \in [1,A]\),求所有滿足\(a_i\)互不相同的序列的\(\prod_i a_i\)的和


clj的題

一下子想到容斥,一開始從普通容斥的角度考慮,問題在於“規定兩個相同,剩下的任意選還可能出現兩個相同”

掃了一眼他的題解,發現他用\(f_i\)表示長i序列的答案。這樣的話就很科學了,規定i個相同其他任選時只會多統計i+1個的

\[ f(i) = s(1) f(i-1) - \binom{i-1}{1} s(2) f(i-2) + \binom{i-1}{2} s(3) f(i-3) -...\\ s(m) = \sum_{i=1}^A i^m \]
但這樣寫是不對的!

昨天晚上還不是很理解,今天早上來想明白了!

統計兩個相同時,每個三個相同其實多統計了\(\binom{2}{1}\)次!

對於i個相同,多統計的次數為\(\prod_{j=2}^{i-1} \binom{j}{j-1} = (i-1)!\)

所以最後的式子還要乘上個階乘

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 505;
inline int read(){
    char c=getchar(); int x=0,f=1;
    while(c<'0'||c>'9') {if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9') {x=x*10+c-'0';c=getchar();}
    return x*f;
}

int A, n, mo, inv2;
inline ll Pow(ll a, int b) {
    ll ans = 1;
    for(; b; b >>= 1, a = a * a %mo)
        if(b & 1) ans = ans * a %mo;
    return ans;
}
ll inv[N], fac[N], facInv[N], b[N], sum, s[N], ans, f[N];
inline ll C(int n, int m) {return fac[n] * facInv[m] %mo * facInv[n-m] %mo;}
void init() {
    inv[1] = fac[0] = facInv[0] = 1;
    for(int i=1; i<=n+1; i++) {
        if(i != 1) inv[i] = (mo - mo/i) * inv[mo%i] %mo;
        fac[i] = fac[i-1] * i %mo;
        facInv[i] = facInv[i-1] * inv[i] %mo;
    }
    b[0] = 1; b[1] = mo - inv2;
    for(int m=2; m<=n; m++) if(~m&1) {
        for(int k=0; k<m; k++) b[m] = (b[m] - C(m+1, k) * b[k]) %mo;
        b[m] = b[m] * Pow(m+1, mo-2) %mo;
        if(b[m] < 0) b[m] += mo;
    }
    b[1] = inv2;
    static ll mi[N];
    mi[0] = 1;
    for(int i=1; i<=n+1; i++) mi[i] = mi[i-1] * A %mo;
    for(int m=1; m<=n; m++) {
        ll t = 0;
        for(int k=0; k<=m; k++) t = (t + C(m+1, k) * b[k] %mo * mi[m+1-k] %mo) %mo;
        s[m] = t * inv[m+1] %mo;
    }
}

int main() {
    freopen("in", "r", stdin);
    A=read(); n=read(); mo=read(); inv2 = (mo+1)/2;
    init();
    f[0] = 1;
    for(int i=1; i<=n; i++) {
        ll t = s[1] * f[i-1] %mo;
        for(int j=2; j<=i; j++) t = (t + ((j&1) ? 1 : -1) * C(i-1, j-1) * fac[j-1] %mo * s[j] %mo * f[i-j] %mo ) %mo;
        f[i] = t;
    }
    printf("%lld\n", (f[n] + mo) %mo);
}

相關文章