重複的串:KMP + dp 的板子題。
暴力dp
設計 \(dp_{k,i,j}\) 表示主串匹配到第 \(i\) 位,模式串有 \(j\) 位已匹配完成,目前已完成 \(k\) 次匹配的方案數。
不難寫出暴力的填表法轉移式子,發現是平方級別的轉移,所以使用填表法不行,我們嘗試刷表法。
\(dp_{k,i,j}\) 要轉移去的地方顯然是第 \(i+1\) 位,於是我們要觀察 \(j\) 裡哪些是合法的,很容易就想到用 KMP 加速這個過程,因為 KMP 做的也是匹配到某一位後,求出下一位的最大匹配長度。
假設 \(m\) 為模式串的長度。
對於 \(dp_{k,i,j}\) 我們列舉第 \(i+1\) 位的字元是什麼,依次和當前需要匹配的字元進行比較,然後用 KMP 的 next 陣列求出下一位的最大匹配長度 \(l\):
- 當最大匹配長度為 \(m\) 時,這時候 \(k\) 就要 \(+1\),所以 \(dp_{k+1,i+1,next_l} \gets dp_{k+1,i+1,next_l}+dp_{k,i,j}\)。
- 當最大匹配長度小於 \(m\) 時,直接轉移即可,\(dp_{k,i+1,l} \gets dp_{k,i+1,l}+dp_{k,i,j}\)。
時間複雜度是 \(O(k|\sum|n)\) 的,顯然不可過。
矩陣最佳化
發現當值域非常大的 \(i\) 確定時,剩下兩維的數量很小,只有 \(90\) 個值左右,考慮矩陣最佳化 dp。
我們把 \(i\) 這一位提取到前面來,作為快速冪的冪次即可。
構造矩陣有點難描述,看程式碼吧,大體就是做一遍上面的那個轉移,注意細節問題就好了。
另外,當 \(j=m\) 時,不能進行轉移。
時間複雜度 \(O((k|\sum|)^3 \log n)\)。
程式碼
#include <bits/stdc++.h>
#define fi first
#define se second
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pi;
const ll mod=998244353;
int n,ne[35],m;
char c[35];
struct mat{
ll a[105][105];
mat(){memset(a,0,sizeof(a));}
mat operator*(const mat &t)const{
mat res;
for(int i=0;i<=100;i++)
{
for(int k=0;k<=100;k++)
{
for(int j=0;j<=100;j++)
{
res.a[i][j]=(res.a[i][j]+a[i][k]*t.a[k][j])%mod;
}
}
}
return res;
}
};
void construct_dp(mat &dp)
{
for(int lv=0;lv<=2;lv++)
{
int ns=1+lv*(m+1);
for(int i=ns;i<ns+m;i++)
{
for(int j=0;j<26;j++)
{
char nc=('a'+j);
int now=i-ns;
while(now&&nc!=c[now+1])now=ne[now];
if(nc==c[now+1])now++;
if(now==m)
{
if(lv==2)continue;
now=ne[now];
dp.a[i][ns+m+1+now]=(dp.a[i][ns+m+1+now]+1)%mod;
}
else dp.a[i][ns+now]=(dp.a[i][ns+now]+1)%mod;
}
}
}
}
void construct_init(mat &s){s.a[1][1]=1;}
mat qpow(mat a,ll b)
{
mat res;
for(int i=0;i<=100;i++)res.a[i][i]=1;
while(b)
{
if(b&1)res=res*a;
a=a*a;
b>>=1;
}
return res;
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>c+1>>n;
m=strlen(c+1);
int now=0;
for(int i=2;i<=m;i++)
{
while(now&&c[now+1]!=c[i])now=ne[now];
if(c[now+1]==c[i])now++;
ne[i]=now;
}
mat dp;
construct_dp(dp);
mat s;
construct_init(s);
dp=s*qpow(dp,n);
ll ans=0;
for(int i=1+2*(m+1);i<1+3*(m+1);i++)ans=(ans+dp.a[1][i])%mod;
cout<<ans;
return 0;
}