由來
無論在使用LLM大模型時,還是使用bert等傳統的模型,對字串進行編碼都是必要的,只有經過編碼後的字串才能參與到後面的模型計算。
以下是在transformers庫下的編碼方式,無論是什麼模型,AutoTokenizer隱藏了很多細節:
query = 'hello'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer.encode(query)
好處是在使用時不用管tokenizer的底層實現,只需要看看配置就可以了,但當需要自己去實現端到端的LLM推理時,就有點摸不著頭腦了。
拆解transformers
因為transformers的庫是python編寫的,所以我們可以直接扒開裡面的原始碼,看看他們的具體實現,這裡以網易的BCE-Embedding為例,看看裡面都做了些什麼。
首先看到BCE-Embedding是在XLMRobertaModel下重新訓練了語料,訓練之後的長度是250005,包含了250000個正常token和5個特殊token。
編碼方式
這5個特殊token可以在模型初始化時看到:
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
那麼正常token是怎麼儲存的呢,可以看到其內部使用的是google的sentencepiece來儲存的:
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(str(vocab_file))
需要注意的是,XLMRobertaModel是fairseq下的模型,那麼其特殊字元的加入位置是不一樣的,另外XLMRobertaModel在末尾加了<mask>字元
Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
---|---|---|---|---|---|---|---|---|---|---|
fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-' |
spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a' |
計算流程
一個query字串近來的流程是怎樣的呢,首先經過query會經過分詞變成多個token piece,具體分詞演算法是bpe,然後模型字典中找token piece對應的id,當然由於特殊token是後來加的,所以優先尋找特殊token。
以下是原始碼中的具體實現,_tokenize方法將字串分解為多個piece,_convert_token_to_id將對應的piece轉換為對應的id,解碼則是反過來的過程,邏輯是一樣的:
def _tokenize(self, text: str) -> List[str]:
# TODO check if the t5/llama PR also applies here
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
spm_id = self.sp_model.PieceToId(token)
# Need to return unknown token if the SP model returned 0
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.fairseq_ids_to_tokens:
return self.fairseq_ids_to_tokens[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
僅僅到這一步,我們便知道,XLMRobertaModel是包了一層特殊token的sentencepiece,在實現時只需要實現一個字典代替sentencepiece,剩餘的把特殊token加入即可。
sentencepiece的實現
這裡我們不免好奇,sentencepiece是怎麼實現的呢?
BPE演算法
通常情況下,Tokenizer有三種粒度:word/char/subword
- word: 按照詞進行分詞,如:
Today is sunday
. 則根據空格或標點進行分割[today, is, sunday, .]
- character:按照單字元進行分詞,就是以char為最小粒度。 如:
Today is sunday.
則會分割成[t, o, d,a,y, .... ,s,u,n,d,a,y, .]
- subword:按照詞的subword進行分詞。如:
Today is sunday.
則會分割成[to, day,is , s,un,day, .]
演算法流程:
- 確定詞表大小,即subword的最大個數V;
- 在每個單詞最後新增一個,並且統計每個單詞出現的頻率;
- 將所有單詞拆分為單個字元,構建出初始的詞表,此時詞表的subword其實就是字元;
- 挑出頻次最高的字元對,比如說
t
和h
組成的th
,將新字元加入詞表,然後將語料中所有該字元對融合(merge),即所有t
和h
都變為th
。新字元依然可以參與後續的 merge以變成更長的piece。 - 重複3,4的操作,直到詞表中單詞數量達到預設的閾值V或者下一個字元對的頻數為1;
sentencepiece中的原始碼實現:
其中SymbolPair代表了piece對,左右分別代表了合併的piece的來源,Symbol帶代表了一個piece,這裡的Symbol採用了類似了鏈式的方法,是為了避免piece合併後的記憶體移動,只需要用prev和next記錄合併前後的鄰居即可,尋找最大頻率的合併piece使用priority_queue的方式。
這裡的實現是上文演算法的345部分,其中12在可以在前置處理過程中得到。
// ref: https://github.com/google/sentencepiece/blob/master/src/bpe_model.cc
Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized,
float alpha) {
// util class begin
struct SymbolPair {
int left; // left index of this pair
int right; // right index of this pair
float score; // score of this pair. large is better.
size_t size; // length of this piece
};
class SymbolPairComparator {
public:
const bool operator()(SymbolPair *h1, SymbolPair *h2) {
return (h1->score < h2->score ||
(h1->score == h2->score && h1->left > h2->left));
}
};
struct Symbol {
int prev; // prev index of this symbol. -1 for BOS.
int next; // next index of tihs symbol. -1 for EOS.
bool freeze = false; // this symbol is never be merged.
string_view_ piece;
};
// util class end
using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
SymbolPairComparator>;
Agenda agenda;
std::vector<Symbol> symbols;
symbols.reserve(normalized.size());
// Reverse merge rules. key: merged symbol, value: pair of original symbols.
std::unordered_map<string_view_, std::pair<string_view_, string_view_>>
rev_merge;
// SymbolPair holder.
std::vector<std::unique_ptr<SymbolPair>> symbol_pair_holder;
// Lookup new symbol pair at [left, right] and inserts it to agenda.
}
- 將所有的normalized之後的string轉換為單個字元:
while (!normalized.empty()) {
Symbol s;
// const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
int mblen =
std::min<int>(normalized.size(), one_char_len(normalized.data()));
s.piece = string_view_(normalized.data(), mblen);
s.prev = index == 0 ? -1 : index - 1;
normalized.remove_prefix(mblen);
s.next = normalized.empty() ? -1 : index + 1;
++index;
symbols.emplace_back(s);
}
這裡判斷單個欄位的長度,取了個巧, (src & 0xFF) >> 4會進行8位截斷,然後右移4位,將普通的Ascii過濾,特殊的字元會被編碼位多個位元組:
static inline size_t one_char_len(const char *src) {
return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4];
}
- 記錄當前piece中可能存在的piece對,並將可能的piece對合並
對應的是關鍵函式是MaybeAddNewSymbolPair,該函式是匿名函式,會嘗試搜尋合併相鄰的兩個piece,這裡嘗試合併兩個piece,如果能夠合併,就加入到agenda佇列中,symbol_pair_holder也會儲存以便下一次合併:
auto MaybeAddNewSymbolPair = [this, &symbol_pair_holder, &symbols, &agenda,
&rev_merge](int left, int right) {
if (left == -1 || right == -1 || symbols[left].freeze ||
symbols[right].freeze) {
return;
}
const string_view_ piece(symbols[left].piece.data(),
symbols[left].piece.size() +
symbols[right].piece.size());
std::string piece_str(piece.to_string());
const auto it = pieces_.find(piece_str);
if (it == pieces_.end()) {
return;
}
symbol_pair_holder.emplace_back(new SymbolPair);
auto *h = symbol_pair_holder.back().get();
h->left = left;
h->right = right;
h->score = get_score(it->second);
h->size = piece.size();
agenda.push(h);
// Makes `rev_merge` for resegmentation.
if (is_unused(it->second)) {
rev_merge[piece] =
std::make_pair(symbols[left].piece, symbols[right].piece);
}
};
在迴圈中,找到agenda中分數最高的,這裡有一定機率的dropout,大概是10%,接著是合併agenda中最高pair到其左邊,然後更新symbols,有點類似與刪除連結串列的操作,只不過這裡是改變前後鄰居,然後迴圈合併前後鄰居:
// Main loop.
while (!agenda.empty()) {
SymbolPair *top = agenda.top();
agenda.pop();
// `top` is no longer available.
if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
top->size) {
continue;
}
if (skip_merge())
continue;
// Replaces symbols with `top` rule.
symbols[top->left].piece = string_view_(
symbols[top->left].piece.data(),
symbols[top->left].piece.size() + symbols[top->right].piece.size());
// Updates prev/next pointers.
symbols[top->left].next = symbols[top->right].next;
if (symbols[top->right].next >= 0) {
symbols[symbols[top->right].next].prev = top->left;
}
symbols[top->right].piece = string_view_("");
// Adds new symbol pairs which are newly added after symbol replacement.
MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
}
相比與理想的演算法,實際實現中多了一步,即將is_unused的id的pair再重新拆回去,具體的逆向字典則在merge時儲存:
std::function<void(string_view_, EncodeResult *)> resegment;
resegment = [this, &resegment, &rev_merge](string_view_ w,
EncodeResult *output) -> void {
std::string w_str(w.to_string());
const int id = piece_to_id(w_str);
// std::cout << "piece: " << w << ", id = " << id << std::endl;
if (id == -1 || !is_unused(id)) {
output->emplace_back(w, id);
return;
}
const auto p = rev_merge.find(w);
if (p == rev_merge.end()) {
// This block will never be called, as `rev_merge` stores all the
// resegmentation info for unused id.
output->emplace_back(w, id);
return;
}
// Recursively resegment left and right symbols.
resegment(p->second.first, output);
resegment(p->second.second, output);
};
經過以上的演算法轉換之後,便可以將string轉換為對應的ids了。
題外
除了BPE的編碼方式,還有BBPE、WordPiece、Unigram等不同的方式,除了在具體處理方式上的不同,總體結構上是大同小異的。