BD202404 110串

zouyua發表於2024-06-05

百度之星一場,t4

題目連結:

對於這種連續狀態限制的字串方案數,首先考慮dp

首先定義好每個狀態方便轉移,0狀態是結尾為0,1狀態是結尾1個連續1,2狀態是結尾兩個連續1,有以下關係

if(s[i] == '1') {
    if(j > 0) dp[i][j][0] = (dp[i][j][0] + dp[i - 1][j - 1][0] + dp[i - 1][j - 1][1]) % mod;
    dp[i][j][1] = (dp[i][j][1] + dp[i - 1][j][0]) % mod;
    dp[i][j][2] = (dp[i][j][2] + dp[i - 1][j][1] + dp[i - 1][j][2]) % mod;
} else {
    dp[i][j][0] = (dp[i][j][0] + dp[i - 1][j][0] + dp[i - 1][j][1]) % mod;
    if(j > 0) dp[i][j][1] = (dp[i][j][1] + dp[i - 1][j - 1][0]) % mod;
    if(j > 0) dp[i][j][2] = (dp[i][j][2] + dp[i - 1][j - 1][1] + dp[i - 1][j - 1][2]) % mod;
}

可以發現每個分類討論的轉移個數不加本身是有5條(轉移圖可知有6條,有一條不合法),這是基於轉移的路徑確定,主要s[i]區別就是轉移的代價變化

  • 本題卡了記憶體,故要最佳化成滾動陣列方式轉移
#include<bits/stdc++.h>
#define int long long
using namespace std;
using ull = unsigned long long;
using ll = long long;
using PII = pair<int,int>;
#define IOS ios::sync_with_stdio(false),cin.tie(0)
#define lowbit(x) (x) & (-x)
#define endl "\n" 
#define pb push_back
const int N=5e3+10;
const int INF=0x3f3f3f3f;
const int mod=998244353;
ll dp[2][N][3];
void solve()
{
    int n, k; cin >> n >> k;
    string s; cin >> s; s = " "  + s;
    dp[0][0][0] = 1;
    for(int i = 1; i <= n; i ++) {
        int u = i & 1;
        for(int j = 0; j <= k; j ++) {
            if(s[i] == '1') {
                if(j > 0) dp[u][j][0] = (dp[u][j][0] + dp[u ^ 1][j - 1][0] + dp[u ^ 1][j - 1][1]) % mod;
                dp[u][j][1] = (dp[u][j][1] + dp[u ^ 1][j][0]) % mod;
                dp[u][j][2] = (dp[u][j][2] + dp[u ^ 1][j][1] + dp[u ^ 1][j][2]) % mod;
            } else {
                dp[u][j][0] = (dp[u][j][0] + dp[u ^ 1][j][0] + dp[u ^ 1][j][1]) % mod;
                if(j > 0) dp[u][j][1] = (dp[u][j][1] + dp[u ^ 1][j - 1][0]) % mod;
                if(j > 0) dp[u][j][2] = (dp[u][j][2] + dp[u ^ 1][j - 1][1] + dp[u ^ 1][j - 1][2]) % mod;
            }
        }
        memset(dp[u ^ 1], 0, sizeof dp[u ^ 1]);
    }
    ll res = 0;
    int u = n & 1;
    for(int j = 0; j <= k; j ++) {
        res = (res + dp[u][j][0] + dp[u][j][1] + dp[u][j][2]) % mod;
    }
    cout << res << endl;
}
signed main()
{
    int T = 1;
    //cin>>T;
    while(T--)
    {
        solve();
    }
    return 0;
}