AC自動機學習筆記

Vocanda發表於2020-10-25

定義

Aho-Corasick automaton,該演算法在1975年產生於貝爾實驗室,是著名的多模匹配演算法。
具體問題大致為多個模式串在一個文字串中匹配查詢的問題。
AC自動機利用某些操作阻止了模式串匹配階段的回溯,將時間複雜度優化到了 \(O(n)\)(n為文字串長度)

前置芝士

基本的 \(Trie\) 樹,\(KMP\) 的失配指標思想。
不會 \(KMP\) 可以先去學習 \(KMP\)演算法

AC 自動機

思想其實挺簡單的,就是在一棵由模式串構造出的字典樹上進行文字串的匹配。
但是如果每次匹配都從根開始匹配,演算法複雜度問題比較大,所以就有了 AC 自動機。
其思想與 \(KMP\) 一致,對於字典樹上的節點建立失配指標,匹配失敗的話直接跳轉到失配指標繼續匹配。
下邊給出一個例子:
我們現在有模式串: abc,bcd,acd,cd。
那麼我們可以建出一個這樣的字典樹

然後給出一個很長的文字串,一個一個從根開始掃複雜度非常不友好,所以我們就會用到 \(fail\) 指標。
\(fail\) 指標的建立是找到兩個串的最長公共字尾,然後連起來,所以我們這個樹的 \(fail\) 指標應該就是這樣的:

每一次匹配失敗之後,從最長的公共字尾開始匹配,這樣一定是對的。

接下來我們關心的就是如何去找 \(fail\) 指標。
首先我們建一棵字典樹,然後利用 \(BFS\) 找,首先把根節點下的點入對,然後找每次出隊的點作為父親,列舉所有可能的兒子(字母一共26個),如果有這個兒子,那麼這個兒子的 \(fail\) 指標就是父親 \(fail\) 指標的當前兒子。如果沒有當前兒子,那麼就把當前節點的當前兒子設為當前節點 \(fail\) 指標的當前兒子。(可能說的不太清楚,當前兒子指的是列舉的兒子是誰)。

下邊給出構造 \(fail\) 指標的程式碼:

queue<int>q;
inline void build(){
	for(int i = 0;i < 26;++i){
		if(t[0][i]) q.push(t[0][i]);
	}
	while(!q.empty()){
		int x = q.front();
		q.pop();
		for(int i = 0;i < 26;++i){
			if(t[x][i]){
				fail[t[x][i]] = t[fail[x]][i];
				q.push(t[x][i]);
			}
			else t[x][i] = t[fail[x]][i];
		}
	}
}

然後在查詢的時候,對於文字串的每個字元,就從根節點一直跳 \(fail\) 指標查詢,一直向下尋找,直到匹配失敗( \(fail\) 指標指向根或者當前節點已找過).

查詢程式碼

inline void query(char *ch){
	int len = strlen(ch + 1);
	int rt = 0;
	for(int i = 1;i <= len;++i){
		int x = ch[i] - 'a';
		rt = t[rt][x];//從該字母節點開始跳
		for(int j = rt;j;j = fail[j])ans[end[j]]++;//直到匹配失敗。
	}
}

到這裡,AC自動機的內容就講解完了,主要就是 \(fail\) 指標的建立和查詢時跳 \(fail\) 指標的操作。

例題

下邊給出洛谷上的三道模板題以供練習:

【模板】AC自動機(簡單版)

題目

題目連結

分析

板子(題目說了),查詢文字串中有多少個不同的模式串。
我們在建立字典樹的時候,對於最末尾節點用陣列 \(end\) 記錄一下有多少個串以這個點作為結尾,在查詢跳 \(fail\) 的時候,累加上跳到的點的 \(end\) 即可。

程式碼

#include<bits/stdc++.h>
using namespace std;
#define gc() getchar()
#define read() ({ register int x = 0, f = 1; register char c = gc(); while(c < '0' || c > '9') { if (c == '-') f = -1; c = gc();} while(c >= '0' && c <= '9') x = x * 10 + (c & 15), c = gc(); f * x; })
char buf[1 << 20], *p1, *p2;
const int maxn = 1e6+10;
char s[maxn];
int n;
int t[maxn][30];
int tot,end[maxn],fail[maxn];
inline void insert(char *ch){//建立字典樹
	int rt = 0;
	int len = strlen(ch+1);
	for(int i = 1;i <= len;++i){
		int x = ch[i] - 'a';
		if(!t[rt][x])t[rt][x] = ++tot;
		rt = t[rt][x];
	}
	end[rt]++;//記錄當前點是幾個單詞的結尾
}
inline void build(){//找fail指標
	queue<int>q;
	memset(fail,0,sizeof(fail));
	for(int i = 0;i < 26;++i){
		if(t[0][i])q.push(t[0][i]);
	}
	while(!q.empty()){
		int x = q.front();
		q.pop();
		for(int i = 0;i < 26;++i){
			if(t[x][i]){
				fail[t[x][i]] = t[fail[x]][i];
				q.push(t[x][i]);
			}
			else{
				t[x][i] = t[fail[x]][i];
			}
		}
	}
}
inline int query(char *ch){//查詢
	int len = strlen(ch+1);
	int rt = 0,ans = 0;
	for(int i = 1;i <= len;++i){
		int x = ch[i] - 'a';
		rt = t[rt][x];
		for(int j = rt;j && end[j] != -1;j = fail[j]){
			ans += end[j];//累加
			end[j] = -1;//清空
		}
	}
	return ans;
}
int main(){
	n = read();
	for(int i = 1;i <= n;++i){
		scanf("%s",s+1);
		insert(s);
	}
	build();
	scanf("%s",s+1);
	printf("%d\n",query(s));
	return 0;
}

【模板】AC自動機(加強版)

題目

題目連結

分析

題目要我們找到出現最多的模式串出現的次數和是誰,那麼我們就在建立字典樹的時候,記錄一下以每個點作為結尾的串是哪個串。
在進行查詢的時候,每次跳 \(fail\) 指標跳到某個點的時候,就給當前點記錄的串的下標的答案++(有些繞,一會可以結合程式碼看看)。
最後對於 \(ans\) 陣列掃兩遍即可。

程式碼

#include<bits/stdc++.h>
using namespace std;
#define gc() getchar()
#define read() ({ register int x = 0, f = 1; register char c = gc(); while(c < '0' || c > '9') { if (c == '-') f = -1; c = gc();} while(c >= '0' && c <= '9') x = x * 10 + (c & 15), c = gc(); f * x; })
char buf[1 << 20], *p1, *p2;
const int maxn = 5e5+10;
int t[maxn][30];
char s[maxn];
char ss[200][maxn];
int tot;
int fail[maxn],end[maxn],ans[maxn];
inline void insert(char *ch,int now){
	int len = strlen(ch+1);
	int rt = 0;
	for(int i = 1;i <= len;++i){
		int x = ch[i] - 'a';
		if(!t[rt][x])t[rt][x] = ++tot;
		rt = t[rt][x];
	}
	end[rt] = now;//記錄當前點是誰的結尾
}
queue<int>q;
inline void build(){//日常建fail指標
	for(int i = 0;i < 26;++i){
		if(t[0][i])q.push(t[0][i]);
	}
	while(!q.empty()){
		int x = q.front();q.pop();
		for(int i = 0;i < 26;++i){
			if(t[x][i]){
				fail[t[x][i]] = t[fail[x]][i];
				q.push(t[x][i]);
			}
			else{
				t[x][i] = t[fail[x]][i];
			}
		}
	}
}
inline void query(char *ch){
	int len = strlen(ch + 1);
	int rt = 0;
	for(int i = 1;i <= len;++i){
		int x = ch[i] - 'a';
		rt = t[rt][x];
		for(int j = rt;j;j = fail[j])ans[end[j]]++;//當前節點代表的單詞出現個數++
	}
}
int main(){
	int n;
	while(1){
		scanf("%d",&n);if(n == 0)break;
		memset(t,0,sizeof(t));//多測注意清空
		memset(end,0,sizeof(end));
		memset(fail,0,sizeof(fail));
		memset(ans,0,sizeof(ans));
		for(int i = 1;i <= n;++i){
			scanf("%s",ss[i]+1);
			insert(ss[i],i);
		}
		build();
		scanf("%s",s+1);
		query(s);
		int mx = 0;//下邊掃兩邊即可
		for(int i = 1;i <= n;++i)if(ans[i] > mx)mx = ans[i];
		printf("%d\n",mx);
		for(int i = 1;i <= n;++i)if(ans[i] == mx)printf("%s\n",ss[i]+1);
	}
	return 0;
}

【模板】AC自動機(二次加強版)

題目

題目連結

分析

與加強版差不了多少,只不過是查詢的東西不一樣了。
我們現在要統計每個串出現多少次,那麼我們就記錄一下每個串的終點是誰。
在查詢的時候,我們不需要跳 \(fail\) ,改成對於每個點訪問次數++,然後再從每個點的 \(fail\) 向當前點建邊。
利用差分,我們把每個點訪問次數++,就相當與把根到當前點上所有點次數都++,這樣只需要建立一棵 \(fail\) 指標連線的樹,然後 \(dfs\) 一遍求出差分答案即可。

程式碼

#include<bits/stdc++.h>
using namespace std;
inline int read(){
	int x = 0, w = 1;
	char ch = getchar();
	for(; ch > '9' || ch < '0'; ch = getchar()) if(ch == '-') w = -1;
	for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
	return x * w;
}
const int maxn = 2e6+10;
char s[maxn];
int tot;
int fail[maxn],end[maxn],sum[maxn],t[maxn][30];
inline void insert(char *ch,int now){
	int	len = strlen(ch+1);
	int rt = 0;
	for(int i = 1;i <= len;++i){
		int x = ch[i] - 'a';
		if(!t[rt][x])t[rt][x] = ++tot;
		rt = t[rt][x];
	}
	end[now] = rt;//記錄當前串結尾是哪個點
}
queue<int>q;
inline void build(){
	for(int i = 0;i < 26;++i){
		if(t[0][i]) q.push(t[0][i]);
	}
	while(!q.empty()){
		int x = q.front();
		q.pop();
		for(int i = 0;i < 26;++i){
			if(t[x][i]){
				fail[t[x][i]] = t[fail[x]][i];
				q.push(t[x][i]);
			}
			else t[x][i] = t[fail[x]][i];
		}
	}
}
inline void query(char *ch){
	int len = strlen(ch+1);
	int rt = 0;
	for(int i = 1;i <= len;++i){
		int x = ch[i] - 'a';
		rt = t[rt][x];
		sum[rt]++;//每個點訪問次數++
	}
}
struct Node{
	int v,next;
}e[maxn<<1];
bool vis[maxn];
int head[maxn],cnt;
inline void Add(int x,int y){
	e[++cnt].v = y;
	e[cnt].next = head[x];
	head[x] = cnt;
}
inline void dfs(int x){
	vis[x] = 1;
	for(int i = head[x];i;i = e[i].next){
		int v = e[i].v;
		if(vis[v])continue;
		dfs(v);
		sum[x] += sum[v];
	}
}
int main(){
	int n = read();
	for(int i = 1;i <= n;++i){
		scanf("%s",s+1);
		insert(s,i);
	}
	build();
	scanf("%s",s+1);
	query(s);
	for(int i = 0;i <= tot;++i)Add(fail[i],i);//建樹
	dfs(0);//差分
	for(int i = 1;i <= n;++i){
		printf("%d\n",sum[end[i]]);//按結尾點出現次數知道串出現次數
	}
}

相關文章