從零開始發明 AC 自動機

August_Light發表於2024-03-09

AC 自動機是一種多模字串匹配演算法。

[Luogu P5357]【模板】AC 自動機

給你一個文字串 $S$ 和 $n$ 個模式串 $T_{1 \sim n}$,請你分別求出每個模式串 $T_i$ 在 $S$ 中出現的次數。

$1 \le n \le 2 \times {10}^5$,$T_{1 \sim n}$ 的長度總和不超過 $2 \times {10}^5$,$S$ 的長度不超過 $2 \times {10}^6$。

下文中涉及時間複雜度的部分,$n$ 為模式串長度之和,$m$ 為文字串長度。

前置知識

  • 字典樹:[Luogu P8306]【模板】字典樹
  • KMP:[Luogu P3375]【模板】KMP(不必要,但是最好了解其思想)。
  • 自動機(DFA)基本概念:https://oi-wiki.org/string/automaton/
  • 自動機五要素:
    • 字符集 $\Sigma$。
    • 狀態集合 $Q$。
    • 起始狀態 $start$。
    • 接收狀態集合 $F$。
    • 轉移函式 $\delta$。$\delta(u,c)$ 中 $u,\delta(u,c) \in Q$,$c \in \Sigma$。

Step 1:AC 自動機基於字典樹

有多個模式串,考慮有什麼簡單的結構能解決多個字串的問題。不難想到雜湊和字典樹。

雜湊可能會碰撞,且看起來跟自動機相關理論沒什麼關係,很難擴充套件。

字典樹可以視作自動機。

建立字典樹時間複雜度 $O(n)$。

int ins(string s) {
    int u = 0;
    for (auto ch : s) {
        int c = ch - 'a';
        if (!tr[u][c])
            tr[u][c] = ++tot;
        u = tr[u][c];
    }
    return u;
}

Step 2:fail 陣列的定義

多模字串匹配是單個模式串匹配的擴充套件,所以考慮 KMP。

KMP 演算法可以視作自動機。基於字串 $s$ 的 KMP 自動機接受且僅接受以 $s$ 為字尾的字串。

那麼 AC 自動機就應該是:基於字串 $s_{1 \sim n}$ 的 AC 自動機接受且僅接受以 $s_{1 \sim n}$ 任意一個為字尾的字串。

考慮在 Trie 上定義一個類似 KMP 中 next 陣列的陣列。

具體地,定義 $fail(u)$ 為 $u$ 表示的字串 最長的出現在 Trie 上的 字尾對應的狀態。

在自動機上連上 $u$ 與 $fail(u)$ 的邊,這條邊被稱為 fail 邊。

從 OI-Wiki 偷一張圖來解釋:

灰色邊為 Trie,黃色邊為 fail 邊。

例如此圖中 $9$ 號連到 $2$ 號,是因為 $\texttt{she}$ 出現在 Trie 上的最長真字尾為 $\texttt{he}$,即 $2$ 號。

不難發現,這個 $fail(u)$ 當 Trie 中只有一個模式串時,就是 KMP 的 next 陣列(這裡的 next 陣列表示 border 長度)。

重要性質:fail 邊形成一棵樹。這是 KMP 的 fail 樹的應用:[Luogu P5829]【模板】失配樹

Step 3:fail 如何求 & 構建 AC 自動機

自動機五要素:

  • 字符集 $\Sigma$,為小寫字母。
  • 狀態集合 $Q$,為 Trie 上的所有節點。
  • 起始狀態 $start$,為 Trie 的根節點 $0$。
  • 接收狀態集合 $F$,為所有模式串在 Trie 上的節點。
  • 轉移函式 $\delta$,下文著重講解這一點。

以下 $tr$ 指原字典樹。

若 $tr_{u,c}$ 存在,則 $\delta(u,c) = tr_{u,c}$,$fail(\delta(u,c)) = \delta(fail(u),c)$。

  • 注意到 $fail(\delta(u,c))$ 基於 $fail(u)$,所以我們 BFS 求解 fail。

若 $tr_{u,c}$ 不存在:

  • 若 $u$ 是根節點 $0$,則 $\delta(u,c) = 0$。
    • 如果沒有這一條,則 $0$ 的兒子的 $fail$ 會連到自身,不滿足真字尾。
  • 否則 $\delta(u,c) = \delta(fail(u),c)$。

最後一條的遞迴與 KMP 的不斷跳 next 是相同的。關於這一點,我們可以看看 KMP 自動機的 $\delta$:

$$\delta(u,c) = \begin{cases} u+1 & c = s_{u+1} \ 0 & c \ne s_{u+1} \land u = 0 \ \delta(next(u),c) & c \ne s_{u+1} \land u \ne 0 \end{cases}$$

再看看 AC 自動機的 $\delta$:

$$\delta(u,c) = \begin{cases} tr_{u,c} & tr_{u,c} \text{ exists} \ 0 & tr_{u,c} \text{ does not exist} \land u = 0 \ \delta(fail(u),c) & tr_{u,c} \text{ does not exist} \land u \ne 0 \end{cases}$$

不能說十分類似,只能說是一模一樣。

程式碼上,我們不用重新建一個自動機,直接按照 AC 自動機的 $\delta$ 改 Trie 的結構即可。時間複雜度 $O(n |\Sigma|)$。

// tr 原本為字典樹
void bfs() {
    queue<int> q;
    for (int c = 0; c < 26; c++)
        if (tr[0][c])
            q.push(tr[0][c]);
    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int c = 0; c < 26; c++)
            if (tr[u][c]) {
                fail[tr[u][c]] = tr[fail[u]][c];
                q.push(tr[u][c]);
            } else
                tr[u][c] = tr[fail[u]][c];
    }
}

你非要新建一個自動機也不是不行。但是空間常數大一倍,沒啥意義。

// tr 為字典樹
// dt 指轉移函式 delta
void bfs() {
    queue<int> q;
    for (int c = 0; c < 26; c++)
        if (tr[0][c])
            q.push(dt[0][c] = tr[0][c]);
    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int c = 0; c < 26; c++)
            if (tr[u][c]) {
                dt[u][c] = tr[u][c];
                fail[dt[u][c]] = dt[fail[u]][c];
                q.push(dt[u][c]);
            } else
                dt[u][c] = dt[fail[u]][c];
    }
}
// 注意後文作匹配的時候要沿著 dt 而不是 tr

Step 4:多模字串匹配

接下來我們就可以把文字串作為輸入給到 AC 自動機。

用一個陣列記錄每一個節點被走過了多少次。

建出 fail 樹,DFS 子樹求和,儲存在 $sum$ 陣列。

此時 $sum_u$ 為 $u$ 對應的字串被匹配到的次數。原因是 fail 樹上,若一節點匹配上了,則其祖先也必然匹配。

第 $i$ 個模式串對應節點的子樹和即為答案。

(這一段看具體程式碼更容易懂。)

總結 & 完整程式碼

  1. 建出 Trie 樹,儲存每個模式串在 Trie 上的位置。$O(n)$。
  2. 把 Trie 樹改造為 AC 自動機,並求出 fail 陣列,建出 fail 樹。$O(n |\Sigma|)$。
  3. 把文字串作為輸入給到 AC 自動機,在 fail 樹上求和得到答案。$O(m)$。

空間複雜度為 $O(n |\Sigma| + m)$。

DFS 用了 lambda 表示式。以普通函式的形式寫一個 DFS 也是沒有問題的。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int MAXN = 2e5 + 5; // 模式串長度之和

int tr[MAXN][26], fail[MAXN], tot = 0;
int e[MAXN], sum[MAXN];
vector<int> G[MAXN];
int ins(string s) {
    int u = 0;
    for (auto ch : s) {
        int c = ch - 'a';
        if (!tr[u][c])
            tr[u][c] = ++tot;
        u = tr[u][c];
    }
    return u;
}
void bfs() {
    queue<int> q;
    for (int c = 0; c < 26; c++)
        if (tr[0][c])
            q.push(tr[0][c]);
    while (!q.empty()) {
        int u = q.front(); q.pop();
        for (int c = 0; c < 26; c++)
            if (tr[u][c]) {
                fail[tr[u][c]] = tr[fail[u]][c];
                q.push(tr[u][c]);
            } else
                tr[u][c] = tr[fail[u]][c];
    }
}

int main() { ios::sync_with_stdio(0); cin.tie(0);
    int n; cin >> n; for (int i = 1; i <= n; i++) {
        string s; cin >> s;
        e[i] = ins(s);
    }
    bfs();
    for (int u = 1; u <= tot; u++)
        G[fail[u]].push_back(u);

    string t; cin >> t;
    int u = 0;
    for (auto ch : t) {
        int c = ch - 'a';
        u = tr[u][c];
        sum[u]++;
    }
    auto dfs = [&](int u, auto&& self) -> void {
        for (auto v : G[u]) {
            self(v, self);
            sum[u] += sum[v];
        }
    };
    dfs(0, dfs);
    for (int i = 1; i <= n; i++)
        cout << sum[e[i]] << '\n';
    return 0;
}

相關文章