abc253E 相鄰元素之差不低於K的序列數

chenfy27發表於2024-03-18

給定n,m,k,找一個序列A[n],使用滿足1<=A[i]<=m,並且任意相鄰兩元素的差的絕對值大於等於k,求滿足條件的序列個數,求998244353取模。
2<=n<=1000; 1<=m<=5000; 0<=k<=m-1

設dp[i][j]表示前i個數,以j結尾的方案數,在計算dp[i+1][k]時,可以列舉j進行統計,複雜度為O(n^3),可以透過字首和最佳化成O(n^2),再用滾動陣列,將空間複雜度從O(n^2)最佳化到O(n)。注意,需要特判k=0的情況。

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=b;i>=a;i--)

template<int MOD>
struct MInt {
    int x;
    int norm(int u) const {u%=MOD; if(u<0) u+=MOD; return u;}
    MInt(int v=0):x(norm(v)) {}
    int val() const {return x;}
    MInt operator-() const {return MInt(norm(MOD-x));}
    MInt inv() const {assert(x!=0); return power(MOD-2);}
    MInt &operator*=(const MInt &o) {x=norm(x*o.x); return *this;}
    MInt &operator+=(const MInt &o) {x=norm(x+o.x); return *this;}
    MInt &operator-=(const MInt &o) {x=norm(x-o.x); return *this;}
    MInt &operator/=(const MInt &o) {*this *= o.inv(); return *this;}
    friend MInt operator*(const MInt &a, const MInt &b) {MInt ans=a; ans*=b; return ans;}
    friend MInt operator+(const MInt &a, const MInt &b) {MInt ans=a; ans+=b; return ans;}
    friend MInt operator-(const MInt &a, const MInt &b) {MInt ans=a; ans-=b; return ans;}
    friend MInt operator/(const MInt &a, const MInt &b) {MInt ans=a; ans/=b; return ans;}
    friend std::istream &operator>>(std::istream &is, MInt &a) {int u; is>>u; a=MInt(u); return is;}
    friend std::ostream &operator<<(std::ostream &os, const MInt &a) {os<<a.val(); return os;}
    MInt power(int b) const {int r=1, t=x; while(b){if(b&1) r=r*t%MOD; t=t*t%MOD; b/=2;} return MInt(r);}
};
using mint = MInt<998244353>;

const int N = 1005;
const int M = 5005;
int n, m, k;
mint dp[M], pre[M];
mint sum(int l, int r) {
    return l <= r ? pre[r] - pre[l-1] : 0;
}
void solve() {
    cin >> n >> m >> k;
    rep(j,1,m) dp[j] = 1;
    partial_sum(dp+1, dp+1+m, pre+1);
    rep(i,2,n) {
        rep(j,1,m) {
            if (k == 0)
                dp[j] = sum(1,m);
            else
                dp[j] = sum(1,j-k) + sum(j+k,m);
        }
        partial_sum(dp+1, dp+1+m, pre+1);
    }
    cout << pre[m].val() << "\n";
}

signed main() {
    cin.tie(0)->sync_with_stdio(0);
    int t = 1;
    while (t--) solve();
    return 0;
}

相關文章