字首樹

n1ce2cv發表於2024-09-26

字首樹

在電腦科學中,trie(又稱字首樹或字典樹)是一種有序樹,用於儲存關聯陣列,其中的鍵通常是字串。與二叉查詢樹不同,鍵不是直接儲存在節點中,而是由節點在樹中的位置決定。一個節點的所有子孫都有相同的字首,也就是這個節點對應的字串,而根節點對應空字串。一般情況下,不是所有的節點都有對應的值,只有葉子節點和部分內部節點所對應的鍵才有相關的值。

  • 根據字首資訊選擇樹上的分支,可以節省大量時間。但比較浪費空間。
// Initializes the trie object.
// NOTE(review): written with a `void` return type, so read literally this
// declares a free function named Trie; it presumably stands for the Trie
// constructor in this interface sketch.
void Trie();
// Inserts the string `word` into the trie.
void insert(string word);
// Returns the number of instances of the string `word` stored in the trie.
int search(string word);
// Returns the number of stored strings that have `prefix` as a prefix.
int prefixNumber(string prefix);
// Removes the string `word` from the trie.
void remove(string word);

動態結構實現

  • 效能較差,不推薦
#include <vector>
#include <cstdlib>
#include <ctime>
#include <string>
#include <unordered_map>

using namespace std;

// One node of the pointer-based trie.
class TrieNode {
public:
    // Number of inserted words whose path runs through this node
    // (the root therefore counts every stored word).
    int pass;
    // Number of inserted words that end exactly at this node.
    int end;
    // Child links, one slot per lowercase letter ('a' -> 0 ... 'z' -> 25);
    // nullptr means "no child". An unordered_map would work as well.
    vector<TrieNode *> nexts;

    TrieNode() : pass(0), end(0), nexts(26, nullptr) {}
};

// Counting trie over lowercase ASCII words, built from heap-allocated nodes.
//
// Precondition for every method: the argument contains only 'a'..'z';
// any other character would index outside the 26 child slots.
class Trie {
private:
    TrieNode *root;

    // Frees `node` and every node reachable from it.
    static void destroy(TrieNode *node) {
        if (node == nullptr) return;
        for (TrieNode *child : node->nexts) destroy(child);
        delete node;
    }

    // Follows `text` from the root and returns the node it leads to,
    // or nullptr when the path does not exist.
    TrieNode *find(const string &text) const {
        TrieNode *cur = root;
        for (char ch : text) {
            cur = cur->nexts[ch - 'a'];
            if (cur == nullptr) return nullptr;
        }
        return cur;
    }

public:
    Trie() : root(new TrieNode()) {}

    ~Trie() { destroy(root); }

    // The trie owns its nodes through raw pointers, so a member-wise copy
    // would lead to a double free; copying is therefore disabled.
    Trie(const Trie &) = delete;
    Trie &operator=(const Trie &) = delete;

    // Inserts one occurrence of `word`.
    void insert(const string &word) {
        TrieNode *cur = root;
        cur->pass++;
        for (char ch : word) {
            const int path = ch - 'a';
            if (cur->nexts[path] == nullptr) cur->nexts[path] = new TrieNode();
            cur = cur->nexts[path];
            cur->pass++;
        }
        cur->end++;
    }

    // Returns how many occurrences of `word` are stored.
    int search(const string &word) const {
        const TrieNode *node = find(word);
        return node == nullptr ? 0 : node->end;
    }

    // Returns how many stored words start with `prefix`.
    int prefixNumber(const string &prefix) const {
        const TrieNode *node = find(prefix);
        return node == nullptr ? 0 : node->pass;
    }

    // Removes one occurrence of `word`; does nothing if it is not stored.
    void remove(const string &word) {
        if (search(word) <= 0) return;
        TrieNode *cur = root;
        cur->pass--;
        for (char ch : word) {
            const int path = ch - 'a';
            TrieNode *child = cur->nexts[path];
            // Decrement the child's counter (not the pointer itself).
            if (--child->pass == 0) {
                // No remaining word uses this branch: free it and unlink it.
                // Everything below it, including the end marker, goes with it.
                destroy(child);
                cur->nexts[path] = nullptr;
                return;
            }
            cur = child;
        }
        cur->end--;
    }
};

靜態結構實現

NC124 字典樹的實現

#include <vector>
#include <string>
#include <iostream>

using namespace std;

// Counting trie over lowercase ASCII words, stored in index-based arrays.
//
// Node 0 is the "no child" sentinel and node 1 is the root. Nodes are
// appended on demand, so the structure cannot overrun its storage.
// Precondition for every method: the argument contains only 'a'..'z'.
class Trie {
public:
    // tree[node][letter] = index of the child node, or 0 if there is none.
    vector<vector<int>> tree;
    // pass[node] = number of stored words whose path runs through node.
    vector<int> pass;
    // end[node] = number of stored words that end at node.
    vector<int> end;
    // Index of the most recently created node.
    int cnt;
    // Node capacity reserved up front; storage still grows beyond it if needed.
    const int maxN = 150001;

    // A freshly constructed trie is ready to use.
    Trie() { build(); }

    // Resets the trie to the empty state (safe to call repeatedly).
    void build() {
        tree.clear();
        pass.clear();
        end.clear();
        tree.reserve(maxN);
        pass.reserve(maxN);
        end.reserve(maxN);
        tree.assign(2, vector<int>(26, 0));
        pass.assign(2, 0);
        end.assign(2, 0);
        cnt = 1;
    }

    // Inserts one occurrence of `word`.
    void insert(const string &word) {
        int cur = 1;
        ++pass[cur];
        for (char ch : word) {
            const int path = ch - 'a';
            int next = tree[cur][path];
            if (next == 0) {
                // Create the node first; the arrays may reallocate while growing.
                next = newNode();
                tree[cur][path] = next;
            }
            cur = next;
            ++pass[cur];
        }
        ++end[cur];
    }

    // Returns how many occurrences of `word` are stored.
    int search(const string &word) {
        int cur = 1;
        for (char ch : word) {
            cur = tree[cur][ch - 'a'];
            if (cur == 0) return 0;
        }
        return end[cur];
    }

    // Returns how many stored words start with `word`.
    int prefixNumber(const string &word) {
        int cur = 1;
        for (char ch : word) {
            cur = tree[cur][ch - 'a'];
            if (cur == 0) return 0;
        }
        return pass[cur];
    }

    // Removes one occurrence of `word`; does nothing if it is not stored.
    void remove(const string &word) {
        if (search(word) <= 0) return;
        int cur = 1;
        // Mirror insert(): the root counts every stored word.
        --pass[cur];
        for (char ch : word) {
            const int path = ch - 'a';
            const int next = tree[cur][path];
            if (--pass[next] == 0) {
                // Nothing else uses this branch: unlink it. The detached nodes
                // are never reached again, so their counters need no cleanup.
                tree[cur][path] = 0;
                return;
            }
            cur = next;
        }
        --end[cur];
    }

private:
    // Appends a zero-initialised node and returns its index.
    int newNode() {
        tree.push_back(vector<int>(26, 0));
        pass.push_back(0);
        end.push_back(0);
        return ++cnt;
    }
};

// Driver: reads a command count, then that many "<opcode> <word>" commands.
//   1 -> insert the word          2 -> remove the word
//   3 -> print YES/NO (stored?)   4 -> print how many words start with it
int main() {
    int m;
    cin >> m;

    Trie trie;
    trie.build();

    for (int remaining = m; remaining > 0; --remaining) {
        int opt;
        string word;
        cin >> opt >> word;
        if (opt == 1) {
            trie.insert(word);
        } else if (opt == 2) {
            trie.remove(word);
        } else if (opt == 3) {
            const bool present = trie.search(word) > 0;
            cout << (present ? "YES" : "NO") << endl;
        } else if (opt == 4) {
            cout << trie.prefixNumber(word) << endl;
        }
        // Any other opcode is ignored.
    }
}

接頭密匙

#include <vector>
#include <string>
#include <iostream>

using namespace std;

// Counting trie over the 12-symbol alphabet { '0'..'9', '#', '-' }, stored in
// index-based arrays.
//
// Node 0 is the "no child" sentinel and node 1 is the root. Nodes are
// appended on demand instead of allocating every possible node up front, so
// small inputs stay cheap and large inputs cannot overrun the storage.
// Precondition: strings contain only digits, '#' and '-'.
class Trie {
public:
    // tree[node][symbol] = index of the child node, or 0 if there is none.
    vector<vector<int>> tree;
    // pass[node] = number of stored strings whose path runs through node.
    vector<int> pass;
    // end[node] = number of stored strings that end at node.
    vector<int> end;
    // Index of the most recently created node.
    int cnt;
    // Node capacity reserved up front; storage still grows beyond it if needed.
    const int maxN = 2000001;

    // A freshly constructed trie is ready to use.
    Trie() { build(); }

    // Resets the trie to the empty state (safe to call repeatedly).
    void build() {
        tree.clear();
        pass.clear();
        end.clear();
        tree.reserve(maxN);
        pass.reserve(maxN);
        end.reserve(maxN);
        tree.assign(2, vector<int>(12, 0));
        pass.assign(2, 0);
        end.assign(2, 0);
        cnt = 1;
    }

    // Maps a symbol to its branch index:
    //   '0'..'9' -> 0..9,  '#' -> 10,  '-' -> 11
    int getPath(char ch) {
        if (ch == '#') return 10;
        if (ch == '-') return 11;
        return ch - '0';
    }

    // Inserts one occurrence of `word`.
    void insert(const string &word) {
        int cur = 1;
        ++pass[cur];
        for (char ch : word) {
            const int path = getPath(ch);
            int next = tree[cur][path];
            if (next == 0) {
                // Create the node first; the arrays may reallocate while growing.
                next = newNode();
                tree[cur][path] = next;
            }
            cur = next;
            ++pass[cur];
        }
        ++end[cur];
    }

    // Returns how many stored strings start with `word`.
    int prefixNumber(const string &word) {
        int cur = 1;
        for (char ch : word) {
            cur = tree[cur][getPath(ch)];
            if (cur == 0) return 0;
        }
        return pass[cur];
    }

private:
    // Appends a zero-initialised node and returns its index.
    int newNode() {
        tree.push_back(vector<int>(12, 0));
        pass.push_back(0);
        end.push_back(0);
        return ++cnt;
    }
};

class Solution {
public:
    // For every key in `b`, counts the keys in `a` whose sequence of adjacent
    // differences starts with the difference sequence of that `b` key.
    vector<int> countConsistentKeys(vector<vector<int>> &b, vector<vector<int>> &a) {
        // Serialises the adjacent differences of a key, terminating each one
        // with '#', e.g. [3,6,50,10] -> "3#44#-40#".
        auto encode = [](const vector<int> &key) {
            string text = "";
            for (size_t k = 1; k < key.size(); ++k) {
                text.append(to_string(key[k] - key[k - 1]) + "#");
            }
            return text;
        };

        Trie trie;
        trie.build();

        // Index every key of `a` by its encoded difference string.
        for (const auto &key: a) {
            trie.insert(encode(key));
        }

        // Each answer is the number of indexed strings with the query as prefix.
        vector<int> res;
        for (const auto &key: b) {
            res.emplace_back(trie.prefixNumber(encode(key)));
        }
        return res;
    }
};

421. 陣列中兩個數的最大異或值

#include <vector>
#include <string>
#include <iostream>

using namespace std;

// Binary trie shared by the Solution below: tree[node][bit] holds the index
// of the child reached through `bit` (0 or 1); 0 means "no child".
vector<vector<int>> tree;
// Index of the most recently created trie node (the root is node 1).
int cnt;
// Highest bit position that needs to be considered for the current input.
int high;

// 計算前導 0 的個數
// Returns the number of leading zero bits in the 32-bit value `i`:
// 32 for zero, 0 for any negative value (sign bit set).
int countLeadingZeros(int i) {
    if (i == 0) return 32;
    if (i < 0) return 0;

    unsigned value = static_cast<unsigned>(i);
    // A positive int has at most 31 leading zeros.
    int zeros = 31;
    // Binary search for the top set bit: whenever the upper part of the
    // remaining window is non-empty, drop the lower `shift` bits.
    for (int shift = 16; shift >= 2; shift /= 2) {
        if (value >= (1u << shift)) {
            zeros -= shift;
            value >>= shift;
        }
    }
    // `value` is now 1, 2 or 3; its bit 1 decides the final position.
    return zeros - static_cast<int>(value >> 1);
}

// Maximum XOR of two numbers, using the file-level binary trie
// (`tree`, `cnt`, `high`).
//
// Precondition: all numbers are non-negative, so `high` never exceeds 30.
// Trie nodes are added on demand, so large inputs cannot overrun the storage.
class Solution {
public:
    // Adds bits high..0 of `num` to the trie, most significant bit first.
    void insert(int num) {
        int cur = 1;
        for (int i = high; i >= 0; i--) {
            const int state = (num >> i) & 1;
            int next = tree[cur][state];
            if (next == 0) {
                next = ++cnt;
                // Grow before linking: resizing may reallocate `tree`.
                if (tree.size() <= static_cast<size_t>(next))
                    tree.resize(static_cast<size_t>(next) + 1, vector<int>(2, 0));
                tree[cur][state] = next;
            }
            cur = next;
        }
    }

    // Resets the node counter, derives `high` from the largest number and
    // inserts every number into the trie.
    void build(vector<int> &nums) {
        // Node 0 is the "no child" sentinel, node 1 the root.
        if (tree.size() < 2) tree.resize(2, vector<int>(2, 0));
        cnt = 1;
        int m = nums.empty() ? 0 : nums[0];
        for (int num: nums)
            if (num > m) m = num;
        // Bits above the top bit of the maximum are zero in every number.
        high = 31 - countLeadingZeros(m);
        for (int num: nums)
            insert(num);
    }

    // Returns the largest value of num ^ x over all x stored in the trie.
    int maxXor(int num) {
        int res = 0;
        int cur = 1;
        for (int i = high; i >= 0; i--) {
            const int state = (num >> i) & 1;
            // Prefer the branch with the opposite bit; fall back if it is absent.
            int want = state ^ 1;
            if (tree[cur][want] == 0) want ^= 1;
            // `want` is now the branch actually taken.
            res |= (state ^ want) << i;
            cur = tree[cur][want];
        }
        return res;
    }

    // Zeroes every node used so far so the storage can be reused.
    void clear() {
        for (int i = 1; i <= cnt; i++)
            tree[i][0] = tree[i][1] = 0;
    }

    // Returns the maximum XOR of any two elements of `nums` (0 if empty).
    int findMaximumXOR(vector<int> &nums) {
        build(nums);
        int res = 0;
        for (int num: nums) {
            const int candidate = maxXor(num);
            if (candidate > res) res = candidate;
        }
        clear();
        return res;
    }
};
#include <vector>
#include <string>
#include <iostream>
#include <unordered_set>
#include <algorithm>

using namespace std;

// todo: 計算前導 0 的個數
// Counts the leading zero bits of the 32-bit value `i`.
// Zero yields 32; a negative value (sign bit set) yields 0.
int countLeadingZeros(int i) {
    if (i == 0) return 32;
    if (i < 0) return 0;

    unsigned rest = static_cast<unsigned>(i);
    // Start from the maximum for a positive value and subtract as higher
    // set bits are discovered.
    int zeros = 31;
    int width = 16;
    while (width > 1) {
        // Anything above the low `width` bits means the top bit lies there.
        if ((rest >> width) != 0u) {
            zeros -= width;
            rest >>= width;
        }
        width >>= 1;
    }
    // Two bits remain; the upper one accounts for the last position.
    return zeros - static_cast<int>(rest >> 1);
}


class Solution {
public:
    // Returns the maximum XOR of any two elements of `nums` (0 if empty).
    //
    // The answer is built bit by bit from the top: for each bit we test
    // whether setting it is achievable by some pair of prefixes.
    // Precondition: all numbers are non-negative.
    int findMaximumXOR(vector<int> &nums) {
        if (nums.empty()) return 0;
        const int m = *max_element(nums.begin(), nums.end());
        int res = 0;
        unordered_set<int> seen;
        // Bits above the top bit of the maximum are zero in every number.
        for (int i = 31 - countLeadingZeros(m); i >= 0; i--) {
            // res holds the bits already settled above position i;
            // `better` is the goal with bit i set as well.
            const int better = res | (1 << i);
            seen.clear();
            for (int num: nums) {
                // Keep bits 31..i of the number and zero the rest.
                const int prefix = (num >> i) << i;
                seen.insert(prefix);
                // prefix ^ x == better  <=>  x == better ^ prefix
                if (seen.find(better ^ prefix) != seen.end()) {
                    res = better;
                    break;
                }
            }
        }
        return res;
    }
};

212. 單詞搜尋 II

#include <vector>
#include <string>
#include <unordered_set>
#include <cstring>

using namespace std;

// Word Search II: finds every dictionary word that can be traced on the board
// through horizontally/vertically adjacent cells without reusing a cell.
//
// The dictionary is kept in an index-based trie whose nodes are appended on
// demand, so the dictionary size is not limited by a fixed node count.
// Precondition: board cells and words contain only 'a'..'z'.
class Solution {
public:
    // Returns each word of `words` found on `board`, once.
    // The board is modified during the search and restored before returning.
    static vector<string> findWords(vector<vector<char>> &board, vector<string> &words) {
        build(words);
        vector<string> ans;
        const int rows = static_cast<int>(board.size());
        const int cols = rows == 0 ? 0 : static_cast<int>(board[0].size());
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                dfs(board, i, j, 1, ans);
            }
        }
        clear();
        return ans;
    }

private:
    // tree[node][letter] = child index, 0 = none (node 0 is the sentinel,
    // node 1 the root).
    static vector<vector<int>> tree;
    // pass[node] = number of not-yet-found words below/at node.
    static vector<int> pass;
    // end[node] = the word ending at node, or empty if none (or already found).
    static vector<string> end;
    // Index of the most recently created node.
    static int cnt;

    // Appends an empty node and returns its index.
    static int newNode() {
        tree.push_back(vector<int>(26, 0));
        pass.push_back(0);
        end.push_back(string());
        return ++cnt;
    }

    // Rebuilds the trie from `words`.
    static void build(vector<string> &words) {
        clear();
        for (const string &word: words) {
            int cur = 1;
            pass[cur]++;
            for (char c: word) {
                const int path = c - 'a';
                int next = tree[cur][path];
                if (next == 0) {
                    // Create the node first; the arrays may reallocate.
                    next = newNode();
                    tree[cur][path] = next;
                }
                cur = next;
                pass[cur]++;
            }
            end[cur] = word;
        }
    }

    // Resets the trie to just the sentinel and an empty root.
    static void clear() {
        tree.assign(2, vector<int>(26, 0));
        pass.assign(2, 0);
        end.assign(2, string());
        cnt = 1;
    }

    // board : the letter grid
    // i, j  : current cell (row i, column j)
    // t     : trie node reached before consuming board[i][j]
    // ans   : collects every word found
    // Returns how many words were collected from this call.
    static int dfs(vector<vector<char>> &board, int i, int j, int t, vector<string> &ans) {
        if (i < 0 || i == static_cast<int>(board.size()) ||
            j < 0 || j == static_cast<int>(board[0].size()) || board[i][j] == 0) {
            // Off the board, or a cell already on the current path.
            return 0;
        }
        const char tmp = board[i][j];
        // Branch index: 'a' -> 0 ... 'z' -> 25
        const int road = tmp - 'a';
        t = tree[t][road];
        // pass[0] is always 0, so a missing branch stops here too, as does a
        // branch whose words have all been found already.
        if (pass[t] == 0) {
            return 0;
        }
        int fix = 0;
        if (!end[t].empty()) {
            fix++;
            ans.push_back(end[t]);
            // Clear the marker so the same word is never reported twice.
            end[t].clear();
        }
        // Mark the cell as used while exploring its neighbours.
        board[i][j] = 0;
        fix += dfs(board, i - 1, j, t, ans);
        fix += dfs(board, i + 1, j, t, ans);
        fix += dfs(board, i, j - 1, t, ans);
        fix += dfs(board, i, j + 1, t, ans);
        // Found words no longer count, which prunes exhausted branches.
        pass[t] -= fix;
        board[i][j] = tmp;
        return fix;
    }
};

vector<vector<int>> Solution::tree;
vector<int> Solution::pass;
vector<string> Solution::end;
int Solution::cnt = 1;

相關文章