[SentencePiece]Tokenizer的原理與實現

wildkid1024發表於2024-08-26

由來

無論在使用LLM大模型時,還是使用bert等傳統的模型,對字串進行編碼都是必要的,只有經過編碼後的字串才能參與到後面的模型計算。
以下是在transformers庫下的編碼方式,無論是什麼模型,AutoTokenizer隱藏了很多細節:

query = 'hello'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer.encode(query)

好處是在使用時不用管tokenizer的底層實現,只需要看看配置就可以了,但當需要自己去實現端到端的LLM推理時,就有點摸不著頭腦了。

拆解transformers

因為transformers的庫是python編寫的,所以我們可以直接扒開裡面的原始碼,看看他們的具體實現,這裡以網易的BCE-Embedding為例,看看裡面都做了些什麼。
首先看到BCE-Embedding是在XLMRobertaModel下重新訓練了語料,訓練之後的長度是250005,包含了250000個正常token和5個特殊token。

編碼方式

這5個特殊token可以在模型初始化時看到:

bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",

那麼正常token是怎麼儲存的呢,可以看到其內部使用的是google的sentencepiece來儲存的:

self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(str(vocab_file))

需要注意的是,XLMRobertaModel是fairseq下的模型,那麼其特殊字元的加入位置是不一樣的,另外XLMRobertaModel在末尾加了<mask>字元

Vocab 0 1 2 3 4 5 6 7 8 9
fairseq '<s>' '<pad>' '</s>' '<unk>' ',' '.' '▁' 's' '▁de' '-'
spm '<unk>' '<s>' '</s>' ',' '.' '▁' 's' '▁de' '-' '▁a'

計算流程

一個query字串近來的流程是怎樣的呢,首先經過query會經過分詞變成多個token piece,具體分詞演算法是bpe,然後模型字典中找token piece對應的id,當然由於特殊token是後來加的,所以優先尋找特殊token。
以下是原始碼中的具體實現,_tokenize方法將字串分解為多個piece,_convert_token_to_id將對應的piece轉換為對應的id,解碼則是反過來的過程,邏輯是一樣的:

    def _tokenize(self, text: str) -> List[str]:
        # TODO check if the t5/llama PR also applies here
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        spm_id = self.sp_model.PieceToId(token)

        # Need to return unknown token if the SP model returned 0
        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        return self.sp_model.IdToPiece(index - self.fairseq_offset)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

僅僅到這一步,我們便知道,XLMRobertaModel是包了一層特殊token的sentencepiece,在實現時只需要實現一個字典代替sentencepiece,剩餘的把特殊token加入即可。

sentencepiece的實現

這裡我們不免好奇,sentencepiece是怎麼實現的呢?

BPE演算法

通常情況下,Tokenizer有三種粒度:word/char/subword

  • word: 按照詞進行分詞,如: Today is sunday. 則根據空格或標點進行分割[today, is, sunday, .]
  • character:按照單字元進行分詞,就是以char為最小粒度。 如:Today is sunday. 則會分割成[t, o, d,a,y, .... ,s,u,n,d,a,y, .]
  • subword:按照詞的subword進行分詞。如:Today is sunday. 則會分割成[to, day,is , s,un,day, .]

演算法流程:

  1. 確定詞表大小,即subword的最大個數V;
  2. 在每個單詞最後新增一個,並且統計每個單詞出現的頻率;
  3. 將所有單詞拆分為單個字元,構建出初始的詞表,此時詞表的subword其實就是字元;
  4. 挑出頻次最高的字元對,比如說th組成的th,將新字元加入詞表,然後將語料中所有該字元對融合(merge),即所有th都變為th。新字元依然可以參與後續的 merge以變成更長的piece。
  5. 重複3,4的操作,直到詞表中單詞數量達到預設的閾值V或者下一個字元對的頻數為1;

sentencepiece中的原始碼實現:

其中SymbolPair代表了piece對,左右分別代表了合併的piece的來源,Symbol帶代表了一個piece,這裡的Symbol採用了類似了鏈式的方法,是為了避免piece合併後的記憶體移動,只需要用prev和next記錄合併前後的鄰居即可,尋找最大頻率的合併piece使用priority_queue的方式。

這裡的實現是上文演算法的345部分,其中12在可以在前置處理過程中得到。

// ref: https://github.com/google/sentencepiece/blob/master/src/bpe_model.cc
Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized,
                                                      float alpha) {
  // util class begin
  struct SymbolPair {
    int left;    // left index of this pair
    int right;   // right index of this pair
    float score; // score of this pair. large is better.
    size_t size; // length of this piece
  };

  class SymbolPairComparator {
  public:
    const bool operator()(SymbolPair *h1, SymbolPair *h2) {
      return (h1->score < h2->score ||
              (h1->score == h2->score && h1->left > h2->left));
    }
  };

  struct Symbol {
    int prev;            // prev index of this symbol. -1 for BOS.
    int next;            // next index of tihs symbol. -1 for EOS.
    bool freeze = false; // this symbol is never be merged.
    string_view_ piece;
  };
  // util class end

  using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
                                     SymbolPairComparator>;
  Agenda agenda;
  std::vector<Symbol> symbols;
  symbols.reserve(normalized.size());
  // Reverse merge rules. key: merged symbol, value: pair of original symbols.
  std::unordered_map<string_view_, std::pair<string_view_, string_view_>>
      rev_merge;
  // SymbolPair holder.
  std::vector<std::unique_ptr<SymbolPair>> symbol_pair_holder;
  // Lookup new symbol pair at [left, right] and inserts it to agenda.
}
  1. 將所有的normalized之後的string轉換為單個字元:
  while (!normalized.empty()) {
    Symbol s;
    // const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
    int mblen =
        std::min<int>(normalized.size(), one_char_len(normalized.data()));
    s.piece = string_view_(normalized.data(), mblen);
    s.prev = index == 0 ? -1 : index - 1;
    normalized.remove_prefix(mblen);
    s.next = normalized.empty() ? -1 : index + 1;
    ++index;
    symbols.emplace_back(s);
  }

這裡判斷單個欄位的長度,取了個巧, (src & 0xFF) >> 4會進行8位截斷,然後右移4位,將普通的Ascii過濾,特殊的字元會被編碼位多個位元組:

static inline size_t one_char_len(const char *src) {
  return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4];
}
  1. 記錄當前piece中可能存在的piece對,並將可能的piece對合並

對應的是關鍵函式是MaybeAddNewSymbolPair,該函式是匿名函式,會嘗試搜尋合併相鄰的兩個piece,這裡嘗試合併兩個piece,如果能夠合併,就加入到agenda佇列中,symbol_pair_holder也會儲存以便下一次合併:

auto MaybeAddNewSymbolPair = [this, &symbol_pair_holder, &symbols, &agenda,
                                &rev_merge](int left, int right) {
    if (left == -1 || right == -1 || symbols[left].freeze ||
        symbols[right].freeze) {
      return;
    }
    const string_view_ piece(symbols[left].piece.data(),
                             symbols[left].piece.size() +
                                 symbols[right].piece.size());
    std::string piece_str(piece.to_string());
    const auto it = pieces_.find(piece_str);
    if (it == pieces_.end()) {
      return;
    }
    symbol_pair_holder.emplace_back(new SymbolPair);
    auto *h = symbol_pair_holder.back().get();
    h->left = left;
    h->right = right;
    h->score = get_score(it->second);
    h->size = piece.size();
    agenda.push(h);

    // Makes `rev_merge` for resegmentation.
    if (is_unused(it->second)) {
      rev_merge[piece] =
          std::make_pair(symbols[left].piece, symbols[right].piece);
    }
  };

在迴圈中,找到agenda中分數最高的,這裡有一定機率的dropout,大概是10%,接著是合併agenda中最高pair到其左邊,然後更新symbols,有點類似與刪除連結串列的操作,只不過這裡是改變前後鄰居,然後迴圈合併前後鄰居:

// Main loop.
  while (!agenda.empty()) {
    SymbolPair *top = agenda.top();
    agenda.pop();

    // `top` is no longer available.
    if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
        symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
            top->size) {
      continue;
    }

    if (skip_merge())
      continue;
    // Replaces symbols with `top` rule.
    symbols[top->left].piece = string_view_(
        symbols[top->left].piece.data(),
        symbols[top->left].piece.size() + symbols[top->right].piece.size());

    // Updates prev/next pointers.
    symbols[top->left].next = symbols[top->right].next;
    if (symbols[top->right].next >= 0) {
      symbols[symbols[top->right].next].prev = top->left;
    }
    symbols[top->right].piece = string_view_("");

    // Adds new symbol pairs which are newly added after symbol replacement.
    MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
    MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
  }

相比與理想的演算法,實際實現中多了一步,即將is_unused的id的pair再重新拆回去,具體的逆向字典則在merge時儲存:

  std::function<void(string_view_, EncodeResult *)> resegment;
  resegment = [this, &resegment, &rev_merge](string_view_ w,
                                             EncodeResult *output) -> void {
    std::string w_str(w.to_string());
    const int id = piece_to_id(w_str);
    // std::cout << "piece: " << w << ", id = " << id << std::endl;
    if (id == -1 || !is_unused(id)) {
      output->emplace_back(w, id);
      return;
    }
    const auto p = rev_merge.find(w);
    if (p == rev_merge.end()) {
      // This block will never be called, as `rev_merge` stores all the
      // resegmentation info for unused id.
      output->emplace_back(w, id);
      return;
    }
    // Recursively resegment left and right symbols.
    resegment(p->second.first, output);
    resegment(p->second.second, output);
  };

經過以上的演算法轉換之後,便可以將string轉換為對應的ids了。

題外

除了BPE的編碼方式,還有BBPE、WordPiece、Unigram等不同的方式,除了在具體處理方式上的不同,總體結構上是大同小異的。

相關文章