定義
Aho-Corasick automaton,該演算法在1975年產生於貝爾實驗室,是著名的多模匹配演算法。
具體問題大致為多個模式串在一個文字串中匹配查詢的問題。
AC自動機利用某些操作阻止了模式串匹配階段的回溯,將時間複雜度優化到了 \(O(n)\)(n為文字串長度)
前置芝士
基本的 \(Trie\) 樹,\(KMP\) 的失配指標思想。
不會 \(KMP\) 可以先去學習 \(KMP\)演算法
AC 自動機
思想其實挺簡單的,就是在一棵由模式串構造出的字典樹上進行文字串的匹配。
但是如果每次匹配都從根開始匹配,演算法複雜度問題比較大,所以就有了 AC 自動機。
其思想與 \(KMP\) 一致,對於字典樹上的節點建立失配指標,匹配失敗的話直接跳轉到失配指標繼續匹配。
下邊給出一個例子:
我們現在有模式串: abc,bcd,acd,cd。
那麼我們可以建出一個這樣的字典樹
然後給出一個很長的文字串,一個一個從根開始掃複雜度非常不友好,所以我們就會用到 \(fail\) 指標。
\(fail\) 指標的建立是找到兩個串的最長公共字尾,然後連起來,所以我們這個樹的 \(fail\) 指標應該就是這樣的:
每一次匹配失敗之後,從最長的公共字尾開始匹配,這樣一定是對的。
接下來我們關心的就是如何去找 \(fail\) 指標。
首先我們建一棵字典樹,然後利用 \(BFS\) 找,首先把根節點下的點入對,然後找每次出隊的點作為父親,列舉所有可能的兒子(字母一共26個),如果有這個兒子,那麼這個兒子的 \(fail\) 指標就是父親 \(fail\) 指標的當前兒子。如果沒有當前兒子,那麼就把當前節點的當前兒子設為當前節點 \(fail\) 指標的當前兒子。(可能說的不太清楚,當前兒子指的是列舉的兒子是誰)。
下邊給出構造 \(fail\) 指標的程式碼:
queue<int>q;
inline void build(){
for(int i = 0;i < 26;++i){
if(t[0][i]) q.push(t[0][i]);
}
while(!q.empty()){
int x = q.front();
q.pop();
for(int i = 0;i < 26;++i){
if(t[x][i]){
fail[t[x][i]] = t[fail[x]][i];
q.push(t[x][i]);
}
else t[x][i] = t[fail[x]][i];
}
}
}
然後在查詢的時候,對於文字串的每個字元,就從根節點一直跳 \(fail\) 指標查詢,一直向下尋找,直到匹配失敗( \(fail\) 指標指向根或者當前節點已找過).
查詢程式碼
inline void query(char *ch){
int len = strlen(ch + 1);
int rt = 0;
for(int i = 1;i <= len;++i){
int x = ch[i] - 'a';
rt = t[rt][x];//從該字母節點開始跳
for(int j = rt;j;j = fail[j])ans[end[j]]++;//直到匹配失敗。
}
}
到這裡,AC自動機的內容就講解完了,主要就是 \(fail\) 指標的建立和查詢時跳 \(fail\) 指標的操作。
例題
下邊給出洛谷上的三道模板題以供練習:
【模板】AC自動機(簡單版)
題目
分析
板子(題目說了),查詢文字串中有多少個不同的模式串。
我們在建立字典樹的時候,對於最末尾節點用陣列 \(end\) 記錄一下有多少個串以這個點作為結尾,在查詢跳 \(fail\) 的時候,累加上跳到的點的 \(end\) 即可。
程式碼
#include<bits/stdc++.h>
using namespace std;
#define gc() getchar()
#define read() ({ register int x = 0, f = 1; register char c = gc(); while(c < '0' || c > '9') { if (c == '-') f = -1; c = gc();} while(c >= '0' && c <= '9') x = x * 10 + (c & 15), c = gc(); f * x; })
char buf[1 << 20], *p1, *p2;
const int maxn = 1e6+10;
char s[maxn];
int n;
int t[maxn][30];
int tot,end[maxn],fail[maxn];
inline void insert(char *ch){//建立字典樹
int rt = 0;
int len = strlen(ch+1);
for(int i = 1;i <= len;++i){
int x = ch[i] - 'a';
if(!t[rt][x])t[rt][x] = ++tot;
rt = t[rt][x];
}
end[rt]++;//記錄當前點是幾個單詞的結尾
}
inline void build(){//找fail指標
queue<int>q;
memset(fail,0,sizeof(fail));
for(int i = 0;i < 26;++i){
if(t[0][i])q.push(t[0][i]);
}
while(!q.empty()){
int x = q.front();
q.pop();
for(int i = 0;i < 26;++i){
if(t[x][i]){
fail[t[x][i]] = t[fail[x]][i];
q.push(t[x][i]);
}
else{
t[x][i] = t[fail[x]][i];
}
}
}
}
inline int query(char *ch){//查詢
int len = strlen(ch+1);
int rt = 0,ans = 0;
for(int i = 1;i <= len;++i){
int x = ch[i] - 'a';
rt = t[rt][x];
for(int j = rt;j && end[j] != -1;j = fail[j]){
ans += end[j];//累加
end[j] = -1;//清空
}
}
return ans;
}
int main(){
n = read();
for(int i = 1;i <= n;++i){
scanf("%s",s+1);
insert(s);
}
build();
scanf("%s",s+1);
printf("%d\n",query(s));
return 0;
}
【模板】AC自動機(加強版)
題目
分析
題目要我們找到出現最多的模式串出現的次數和是誰,那麼我們就在建立字典樹的時候,記錄一下以每個點作為結尾的串是哪個串。
在進行查詢的時候,每次跳 \(fail\) 指標跳到某個點的時候,就給當前點記錄的串的下標的答案++(有些繞,一會可以結合程式碼看看)。
最後對於 \(ans\) 陣列掃兩遍即可。
程式碼
#include<bits/stdc++.h>
using namespace std;
#define gc() getchar()
#define read() ({ register int x = 0, f = 1; register char c = gc(); while(c < '0' || c > '9') { if (c == '-') f = -1; c = gc();} while(c >= '0' && c <= '9') x = x * 10 + (c & 15), c = gc(); f * x; })
char buf[1 << 20], *p1, *p2;
const int maxn = 5e5+10;
int t[maxn][30];
char s[maxn];
char ss[200][maxn];
int tot;
int fail[maxn],end[maxn],ans[maxn];
inline void insert(char *ch,int now){
int len = strlen(ch+1);
int rt = 0;
for(int i = 1;i <= len;++i){
int x = ch[i] - 'a';
if(!t[rt][x])t[rt][x] = ++tot;
rt = t[rt][x];
}
end[rt] = now;//記錄當前點是誰的結尾
}
queue<int>q;
inline void build(){//日常建fail指標
for(int i = 0;i < 26;++i){
if(t[0][i])q.push(t[0][i]);
}
while(!q.empty()){
int x = q.front();q.pop();
for(int i = 0;i < 26;++i){
if(t[x][i]){
fail[t[x][i]] = t[fail[x]][i];
q.push(t[x][i]);
}
else{
t[x][i] = t[fail[x]][i];
}
}
}
}
inline void query(char *ch){
int len = strlen(ch + 1);
int rt = 0;
for(int i = 1;i <= len;++i){
int x = ch[i] - 'a';
rt = t[rt][x];
for(int j = rt;j;j = fail[j])ans[end[j]]++;//當前節點代表的單詞出現個數++
}
}
int main(){
int n;
while(1){
scanf("%d",&n);if(n == 0)break;
memset(t,0,sizeof(t));//多測注意清空
memset(end,0,sizeof(end));
memset(fail,0,sizeof(fail));
memset(ans,0,sizeof(ans));
for(int i = 1;i <= n;++i){
scanf("%s",ss[i]+1);
insert(ss[i],i);
}
build();
scanf("%s",s+1);
query(s);
int mx = 0;//下邊掃兩邊即可
for(int i = 1;i <= n;++i)if(ans[i] > mx)mx = ans[i];
printf("%d\n",mx);
for(int i = 1;i <= n;++i)if(ans[i] == mx)printf("%s\n",ss[i]+1);
}
return 0;
}
【模板】AC自動機(二次加強版)
題目
分析
與加強版差不了多少,只不過是查詢的東西不一樣了。
我們現在要統計每個串出現多少次,那麼我們就記錄一下每個串的終點是誰。
在查詢的時候,我們不需要跳 \(fail\) ,改成對於每個點訪問次數++,然後再從每個點的 \(fail\) 向當前點建邊。
利用差分,我們把每個點訪問次數++,就相當與把根到當前點上所有點次數都++,這樣只需要建立一棵 \(fail\) 指標連線的樹,然後 \(dfs\) 一遍求出差分答案即可。
程式碼
#include<bits/stdc++.h>
using namespace std;
inline int read(){
int x = 0, w = 1;
char ch = getchar();
for(; ch > '9' || ch < '0'; ch = getchar()) if(ch == '-') w = -1;
for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
return x * w;
}
const int maxn = 2e6+10;
char s[maxn];
int tot;
int fail[maxn],end[maxn],sum[maxn],t[maxn][30];
inline void insert(char *ch,int now){
int len = strlen(ch+1);
int rt = 0;
for(int i = 1;i <= len;++i){
int x = ch[i] - 'a';
if(!t[rt][x])t[rt][x] = ++tot;
rt = t[rt][x];
}
end[now] = rt;//記錄當前串結尾是哪個點
}
queue<int>q;
inline void build(){
for(int i = 0;i < 26;++i){
if(t[0][i]) q.push(t[0][i]);
}
while(!q.empty()){
int x = q.front();
q.pop();
for(int i = 0;i < 26;++i){
if(t[x][i]){
fail[t[x][i]] = t[fail[x]][i];
q.push(t[x][i]);
}
else t[x][i] = t[fail[x]][i];
}
}
}
inline void query(char *ch){
int len = strlen(ch+1);
int rt = 0;
for(int i = 1;i <= len;++i){
int x = ch[i] - 'a';
rt = t[rt][x];
sum[rt]++;//每個點訪問次數++
}
}
struct Node{
int v,next;
}e[maxn<<1];
bool vis[maxn];
int head[maxn],cnt;
inline void Add(int x,int y){
e[++cnt].v = y;
e[cnt].next = head[x];
head[x] = cnt;
}
inline void dfs(int x){
vis[x] = 1;
for(int i = head[x];i;i = e[i].next){
int v = e[i].v;
if(vis[v])continue;
dfs(v);
sum[x] += sum[v];
}
}
int main(){
int n = read();
for(int i = 1;i <= n;++i){
scanf("%s",s+1);
insert(s,i);
}
build();
scanf("%s",s+1);
query(s);
for(int i = 0;i <= tot;++i)Add(fail[i],i);//建樹
dfs(0);//差分
for(int i = 1;i <= n;++i){
printf("%d\n",sum[end[i]]);//按結尾點出現次數知道串出現次數
}
}