D - Avoid K Palindrome

lightsong發表於2024-07-07

D - Avoid K Palindrome

https://atcoder.jp/contests/abc359/tasks/abc359_d

思路

https://atcoder.jp/contests/abc359/submissions/54822869

狀壓DP

以 K 個二進位位元表示一個長度為 K 的字串(由 A、B 組成,A 記為 0、B 記為 1),並預先判斷、記錄每個狀態是否為迴文。

dp[i][j]:考慮前 i+1 個字元(下標 0..i),最後 K 個字元的狀壓表示為 j,且到目前為止沒有任何長度為 K 的迴文子字串(即仍為 good string)的字串個數。

初始化 dp[K-1]:列舉前 K 個字元的所有非迴文狀態,若該狀態與 S 中已確定的字元(A 或 B)全部相符,則計數為 1。

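計算後續 dp:對每個位置 i ≥ K,依 S[i] 為 A、B 或 ? 嘗試附加對應字元,新狀態為 ((j << 1) | 新字元位元) & (2^K − 1);只有新狀態不是迴文時才轉移。最後答案為 dp[N-1] 所有狀態之和(mod 998244353)。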

code

https://atcoder.jp/contests/abc359/submissions/55346351

#include <iostream>
#include <string>

using namespace std;

typedef long long ll;
// All counts are reported modulo this prime.
const ll mod = 998244353;

// Input: string length n, forbidden palindrome length k, and the pattern s
// over {'A', 'B', '?'}.
ll n, k;
string s;

// Table capacities: up to 1000 characters and up to 2^10 window states,
// each with a little slack.
const ll nsize = 1005;
// Parenthesized on purpose: `1<<10 + 5` parses as `1 << 15` because `+`
// binds tighter than `<<`, which silently blew dp up to roughly 250 MiB.
const ll kstatesize = (1 << 10) + 5;
// mirror[st] is true when the k-bit window st reads the same both ways.
bool mirror[kstatesize];

// dp[i][st] (for i >= k-1): number of fillings of s[0..i] with no
// palindromic length-k window so far, whose last k characters encode to st.
ll dp[nsize][kstatesize];

/*
    Returns true when the k-bit window kseq is a palindrome, i.e. bit p
    equals bit k-1-p for every position p. The window width is the global k.
    Examples of palindromic windows:
    101
    11011
*/
bool checkmirror(ll kseq){
    // Compare bits pairwise, walking inward from both ends until they meet;
    // the middle bit of an odd-width window never needs checking.
    for (ll lo = 0, hi = k - 1; lo < hi; ++lo, --hi) {
        const ll loBit = (kseq >> lo) & 1;
        const ll hiBit = (kseq >> hi) & 1;
        if (loBit != hiBit) {
            return false;
        }
    }
    return true;
}

int main()
{
    cin >> n >> k;
    cin >> s;

    for(int i=0; i<(1<<k); i++){
        mirror[i] = checkmirror(i);
    }

    /*
    initialize the first k seq of dp
    */
    for(int i=0; i<(1<<k); i++){
        // only non-mirror seq takes effect
        if (mirror[i]){
            continue;
        }

        /*
        suppose the first k seq follow the i case
        increase dp
        */
        dp[k-1][i]++;

        /*
        then detect if any break with i case,
        if yes, decrease dp
        */
        for(int j=0; j<k; j++){
            /*
            in i case, for each bit,
            0    --    A
            1   --  B
            */
            ll jbitpos = k - j - 1;
            ll jbit = (i>>jbitpos)&1;

            if (s[j]=='A' && jbit==1){
                dp[k-1][i]--;
                break;
            }

            if (s[j]=='B' && jbit==0){
                dp[k-1][i]--;
                break;
            }
        }
//        cout<<dp[k-1][i];
    }

    /*
    now iterate from k to n-1 to calculate the following dp
    */
    for(int i=k; i<n; i++){
        /*
        iterate each non-mirror cases
        */
        for(int j=0; j<(1<<k); j++){
            // as of this case j, the previous state i-1 is not a good string
            // i.e. dp == 0
            // for the new added char of i index, the new string is still not a good string
            // so skip
            if (dp[i-1][j] == 0){
                continue;
            }

            // if the new added char of i index is not A,
            // the possible value is B or ?
            // let's make the possible state transfer
            if (s[i]!='A'){
                // newk is appended by B, and removed the left-most char
                ll newk = ((j<<1)|1)&((1<<k)-1);
//                cout<<newk;
                // if newk is not mirror, the new string is a good string
                // newk is not mirror, make state stranfer
                if (!mirror[newk]){
                    dp[i][newk] = (dp[i][newk] + dp[i-1][j]) % mod;
                }
            }

            // if the new added char of i index is not B,
            // the possible value is A or ?
            // let's make the possible state transfer
            if (s[i]!='B'){
                // newk is appended by A, and removed the left-most char
                ll newk = (j<<1)&((1<<k)-1);
//                cout<<newk;
                // if newk is not mirror, the new string is a good string
                // newk is not mirror, make state stranfer
                if (!mirror[newk]){
                    dp[i][newk] = (dp[i][newk] + dp[i-1][j]) % mod;
                }
            }
        }
//        for(int j=0;j<(1<<k);j++)cout<<dp[i][j];
//        cout<<'\n';
    }

    // now we get dp[n-1],
    // let calculate total number of all states
    ll ans = 0;
    for(int i=0; i<(1<<k); i++){
        ans = (ans + dp[n-1][i]) % mod;
//        cout<<dp[n-1][i];
    }

    cout << ans << endl;

    return 0;
}

相關文章