AC自動機 提高篇

zxh923發表於2024-07-30

文字生成器

首先考慮一個容斥,算出不包含任何一個單詞的文章的數量。

我們設 \(dp_{i,j}\) 表示當前文章長度為 \(i\),最後一個字元在 \(AC\) 自動機上的 \(j\) 號點的方案數。我們要求的答案就是 \(\displaystyle 26^m-\sum_{i=0}^{idx}f_{m,i}\)

於是我們考慮怎麼轉移。

首先,我們在建立 \(AC\) 自動機的時候,如果發現一個節點 \(i\) 指向的 \(ne_i\) 節點有結束標記,那麼我們把這個點也打上結束標記。其他東西沒有區別

然後我們對於 \(AC\) 自動機上的每一個點 \(j\),設他的其中一個位元組點為 \(k\)。那麼如果 \(k\) 節點沒有結束標記,則 \(f_{i,k}+f_{i-1,j}\leftarrow f_{i,k}\)。於是就做完了,程式碼:

#include<bits/stdc++.h>
#define int long long
#define N 10005
#define M 105
#define mod 10007
using namespace std;
int tr[N][26],cnt[N],ne[N],n,m,idx,f[M][N];
char s[N];
//f_{i,j}表示字串長度i,最後一個字元的點為j,且不包含任意一個模式串的方案數
int ksm(int x,int y){
	int res=1;
	while(y){
		if(y&1)(res*=x)%=mod;
		(x*=x)%=mod;
		y>>=1;
	}
	return res;
}
void ins(){
	int p=0;
	for(int i=0;s[i];i++){
		int t=s[i]-'A';
		if(tr[p][t]==0)tr[p][t]=++idx;
		p=tr[p][t];
	}
	cnt[p]=1;
}
void build(){
	queue<int>q;
	for(int i=0;i<26;i++){
		if(tr[0][i]!=0){
			q.push(tr[0][i]);
		}
	}
	while(!q.empty()){
		int t=q.front();
		q.pop();
		for(int i=0;i<26;i++){
			int c=tr[t][i];
			if(c==0){
				tr[t][i]=tr[ne[t]][i];
			}
			else{
				if(cnt[tr[ne[t]][i]]==1)cnt[tr[t][i]]=1;
				ne[c]=tr[ne[t]][i];
				q.push(c);
			}
		}
	}
}
signed main(){
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		cin>>s;
		ins();
	}
	build();
	f[0][0]=1;//空字串,方案數1
	for(int i=1;i<=m;i++){
		for(int j=0;j<=idx;j++){
			for(int k=0;k<26;k++){
				if(cnt[tr[j][k]]==1)continue;
				(f[i][tr[j][k]]+=f[i-1][j])%=mod;
			}
		}
	}
	int sum=ksm(26,m),res=0;
	for(int i=0;i<=idx;i++){
		(res+=f[m][i])%=mod;
	}
	if(sum<res)sum+=mod;
	cout<<sum-res;
	return 0;
}

數數

發現就是上一題的 \(dp\) 變成了數位 \(dp\),於是我們採用記憶化搜尋實現。

我們設 \(f_{i,j,k,l}\) 表示當前到了第 \(i\) 位數,最後一個數字在 \(AC\) 自動機上的位置為 \(j\),並且當前前 \(i\) 位是或不是和限制的數完全相同,當前前 \(i\) 位是不是全是 \(0\)

但是由於我們記搜是最後到只剩一位才返回答案,相當於是倒著算的,所以我們初始要把限制數字翻轉。

剩下的 \(dp\) 方式幾乎和上一題一樣,程式碼:

#include<bits/stdc++.h>
#define int long long
#define N 1605
#define mod 1000000007
using namespace std;
int tr[N][10],cnt[N],ne[N],n,m,idx,f[N][N][2][2];
string t;
char s[N];
void ins(string s){
	int p=0;
	for(int i=0;s[i];i++){
		int t=s[i]-'0';
		if(tr[p][t]==0)tr[p][t]=++idx;
		p=tr[p][t];
	}
	cnt[p]=1;
}
void build(){
	queue<int>q;
	for(int i=0;i<10;i++){
		if(tr[0][i]!=0){
			q.push(tr[0][i]);
		}
	}
	while(!q.empty()){
		int t=q.front();
		q.pop();
		for(int i=0;i<10;i++){
			int c=tr[t][i];
			if(c==0){
				tr[t][i]=tr[ne[t]][i];
			}
			else{
				ne[c]=tr[ne[t]][i];
				cnt[c]|=cnt[ne[c]];
				q.push(c);
			}
		}
	}
}
int dp(int dep,int ac_pos,bool is_lim,bool has_zer){
	if(dep==0)return cnt[ac_pos]==0;
	if(cnt[ac_pos]==1)return 0;
	int &v=f[dep][ac_pos][is_lim][has_zer];
	if(v!=-1)return v;
	int lim=is_lim?(s[dep]-'0'):9ll;
	int sum=0;
	for(int i=0;i<=lim;i++){
		int p1=(has_zer&&(i==0))?0:tr[ac_pos][i];
		bool f1=(is_lim&&(i+'0'==s[dep]));
		bool f2=(has_zer&&(i==0));
		(sum+=dp(dep-1,p1,f1,f2))%=mod;
	}
	return v=sum;
}
signed main(){
	cin>>s+1>>n;
	int len=strlen(s+1);
	reverse(s+1,s+len+1);
	memset(f,-1,sizeof f);
	for(int i=1;i<=n;i++){
		cin>>t;
		ins(t);
	}
	build();
	int res=dp(len,0,1,1)+mod-1;
	if(res>mod)res-=mod;
	cout<<res;
	return 0;
}

相關文章