完全解析!Bert & Transformer 閱讀理解原始碼詳解

红色石头發表於2021-07-19

接上一篇:

你所不知道的 Transformer!

超詳細的 Bert 文字分類原始碼解讀 | 附原始碼

中文情感分類單標籤

參考論文:

https://arxiv.org/abs/1706.03762

https://arxiv.org/abs/1810.04805

在本文中,我將以run_squad.py以及SQuAD資料集為例介紹閱讀理解的原始碼,官方程式碼基於tensorflow-gpu 1.x,若為tensorflow 2.x版本,會有各種錯誤,建議切換版本至1.14。

當然,註釋好的原始碼在這裡:

https://github.com/sherlcok314159/ML/tree/main/nlp/code

章節

  • Demo傳參
  • 資料篇

    • 番外句子分類
    • 創造例項

    • 例項轉換

  • 模型構造

  • 寫入預測

Demo傳參

python bert/run_squad.py \
  --vocab_file=uncased_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt \
  --do_train=True \
  --train_file=SQUAD_DIR/train-v2.0.json \
  --train_batch_size=8 \
  --learning_rate=3e-5 \
  --num_train_epochs=1.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=/tmp/squad2.0_base/ \
  --version_2_with_negative=True

閱讀原始碼最重要的一點不是拿到就讀,而是跑通原始碼裡面的小demo,因為你跑通demo就意味著你對程式碼的一些基礎邏輯和引數有了一定的瞭解。

前面的引數都十分常規,如果不懂,建議看我的文字分類的講解。這裡講一下比較特殊的最後一個引數,我們做的任務是閱讀理解,如果有答案缺失,在SQuAD1.0是不可以的,但是在SQuAD允許,這也就是True的意思。

需要注意,不同人的檔案路徑都是不一樣的,你不能照搬我的,要改成自己的路徑。

資料篇

其實閱讀理解任務模型是跟文字分類幾乎是一樣的,大的差異在於兩者對於資料的處理,所以本篇文章重點在於如何將原生的資料轉換為閱讀理解任務所能接受的資料,至於模型構造篇,請看文字分類:

https://github.com/sherlcok314159/ML/blob/main/nlp/tasks/text.md

番外句子分類

想必很多人看到SquadExample類的repr方法都很疑惑,這裡處理好一個example,為什麼後面還要進行處理?看英文註釋會發現這個類其實跟閱讀理解沒關係,它只是處理之後對於句子分類任務的,自然在run_squad.py裡面沒被呼叫。repr方法只是在有start_position的時候進行字串的拼接。

創造例項

用於訓練的資料集是json檔案,需要用json庫讀入。

訓練集的樣式如下,可見data是最外層的

{
    "data": [
        {
            "title": "University_of_Notre_Dame",
            "paragraphs": [
                {
                    "context": "Architecturally, the school has a Catholic character.",
                    "qas": [
                        {
                            "answers": [
                                {
                                    "answer_start": 515,
                                    "text": "Saint Bernadette Soubirous"
                                }
                            ],
                            "question": "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?",
                            "id": "5733be284776f41900661182"
                        }
                    ]
                }
            ]
        },
        {
            "title":"...",
            "paragraphs":[
                {
                    "context":"...",
                    "qas":[
                        {
                            "answers":[
                                {
                                    "answer_start":..,
                                    "text":"...",
                                }
                            ],
                            "question":"...",
                            "id":"..."
                        },
                    ]
                }
            ]
        }
    ]
}

input_data是一個大列表,然後每一個元素樣式如下

{'paragraphs': [{...}, {...}, {...}, {...}, {...}, {...}, {...}, {...}, {...}, ...], 'title': 'University_of_Notre_Dame'}

is_whitespace方法是用來判斷是否是一個空格,在切分字元然後加入doc_tokens會用到。

然後我們層層剝開,然後遍歷context的內容,它是一個字串,所以遍歷的時候會遍歷每一個字母,字元會被進行判斷,如果是空格,則加入doc_tokens,char_to_word_offset表示切分後的索引列表,每一個元素表示一個詞有幾個字元組成。

切分後的doc_tokens會去掉空白部分,同時會包括英文逗號。一個單詞會有很多字元,每個字元對應的索引會存在char_to_word_offset,例如,前面都是0,代表這些字元都是第一個單詞的,所以都是0,換句話說就是第一個單詞很長。

doc_tokens = ['Architecturally,', 'the', 'school', 'has', 'a', 'Catholic', 'character.', 'Atop', 'the',"..."]

char_to_word_offset = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]

接下來進行qas內容的遍歷,每個元素稱為qa,進行id和question內容的分配,後面都是初始化一些引數

qa裡面還有一個is_impossible,用於判斷是否有答案

確保有答案之後,剛剛讀入了問題,現在讀入與答案相關的部分,讀入的時候注意start_position和end_position是相對於doc_tokens的

接下來對答案部分進行雙重檢驗,actual_text是根據doc_tokens和始末位置拼接好的內容,然後對orig_answer_text進行空格切分,最後用find方法判斷orig_answer_text是否被包含在actual_text裡面。

這個是針對is_impossible來說的,如果沒有答案,則把始末位置全部變成-1。

然後將example變成SquadExample的例項化物件,將example加入大列表——examples並返回,至此例項建立完成。

例項轉換

把json檔案變成例項之後,我們還差一步便可以把資料塞進模型進行訓練了,那就是將例項轉化為變數。

先對question_text進行簡單的空格切分變為query_tokens

如果問題過長,就進行截斷操作

接下來對doc_tokens進行空格切分以及詞切分,變成all_doc_tokens,需要注意的是orig_to_tok_index代表的是doc_tokens在all_doc_tokens的索引,取最近的一個,而tok_to_orig_index代表的是all_doc_tokens在doc_tokens索引

對tok_start_position和tok_end_position進行初始化,記住,這兩個是相對於all_doc_tokens來說的,一定要與start_position和end_position區分開來,它們是相對於doc_tokens來說的

接下來先介紹_improve_answer_span方法,這個方法是用來處理特殊的情況的,舉個例子,假如說你的文字是”The Japanese electronics industry is the lagest in the world.”,你的問題是”What country is the top exporter of electornics?” 那答案其實應該是Japan,可是呢,你用空格和詞切分的時候會發現Japanese已經在詞表中可查,這意味著不會對它進行再切分,會直接將它返回,這種情況下可能需要這個方法救場。

因為是監督學習,答案已經給出,所以呢,這個方法乾的事情就是詞切分後的tokens進行再一次切分,如果發現切分之後會有更好的答案,就返回新的始末點,否則就返回原來的。

對tok_start_position和tok_end_position進行進一步賦值

計算max_tokens_for_doc,與文字分類類似,需要減去[CLS]和兩個[SEP]的位置,這裡不同的是還要減去問題的長度,因為這裡算的是文字的長度。

tokens = [CLS] query tokens [SEP] context [SEP]

很多時候文章長度大於maximum_sequence_length的時候,這個時候我們要對文章進行切片處理,把它按照一定長度進行切分,每一個切片稱為一個doc_span,start代表從哪開始,length代表一個的長度。

doc_spans儲存很多個doc_span。這裡對視窗的長度有所限制,規定了start_offset不能比doc_stride大,這是第二個視窗的起點,從這個角度或許可以理解doc_stride代表平滑的長度。

接下來的操作跟文字分類有些類似,新增[CLS],然後新增問題和[SEP],這些在segment_ids裡面都為0。

下面講_check_is_max_context方法,這個方法是用來判斷某個詞是否具有完備的上下文關係,原始碼給了一個例子:

Span A: the man went to the

Span B: to the store and bought

Span C: and bought a gallon of …

那麼對於bought來說,它在Span B和Span C中都有出現,那麼,哪一個上下文關係最全呢?其實我們憑直覺應該可以猜到應該是Span C,因為Span B中bought出現在句末,沒有下文。當然了,我們還是得用公式計算一下

score = min(num_left_context, num_right_context) + 0.01 * doc_span.length

score_B = min(4, 0) + 0.05 = 0.05

score_C = min(1,3) + 0.05 = 1.05

所以,在Span C中,bought的上下文語義最全,最終該方法會返回True or False,在滑動視窗這個方法中,一個詞很可能出現在多個span裡面,所以用這個方法判斷當前這個詞在當前span裡面是否具有最完整的上下文

回到上面,token_to_orig_map是用來記錄文章部分在all_doc_tokens的索引,而token_is_max_context是記錄文章每一個詞在當前span裡面是否具有最完整的上下文關係,因為一開始只有一個span,那麼一開始每個詞肯定都是True。split_token_index用於切分成每一個token,這樣可以進行上下文關係判斷,至於後面添[SEP]和segment_ids添1這種操作文字分類也有。

接下來將tokens(精細化切分後的)按照詞表轉化為id,另外若不足,則把0填充進去這種操作也是很常見的。

前面是進行判斷,如果切了之後答案並不在span裡面就直接捨棄,若在裡面,因為一開始all_doc_tokens裡面沒有問題和[CLS],[SEP]時正文的索引是tok_start_position,然後轉換為input_ids又有問題以及[CLS],[SEP],所以要得到正文索引需要跳過它們。

接下來大量的tf.logging只是寫入日誌資訊,同時也是你終端或輸出那裡看到的。

最終用這些引數例項化InputFeatures物件,然後不斷重複,每一個feature對應著一個特殊的id,即為unique_id。

模型構建

這裡大致與文字分類差不多,只是文字分類在模型裡面直接進行了softmax處理,然後進行最小交叉熵損失,而這次我們沒有直接這樣做,得到了開頭和結尾處的未歸一化的概率logits,之後我們直接返回。

然後這次我們是在model_fn_builder方法裡面的子方法model_fn裡定義compute_loss,其實這裡也是經過softmax進行歸一化,然後再計算交叉熵損失,最終返回均方誤差。

然後我們計算開頭和結尾處的損失,總損失為二者和的平均。

最終我們進行優化。

寫入預測

start_logit & end_logit 代表著未經過softmax的概率,start_logit表示tokens裡面以每一個token作為開頭的概率,後者類似的。還有一對null_start_logit & null_end_logit,它們兩個代表的是SQuAD2.0沒有答案的那些,預設全為0。

首先,簡單介紹一下_get_best_indexes,這個方法是用來輸出由高到低前n_best_size個的概率的索引。

遍歷start_indexes,end_indexes(都是分別經過_get_best_indexes得到),對於答案未缺失的,以具體的logit填入,另外,feature_index代表第幾個feature。

如果答案缺失,則全都為0

接下來我們進一步轉換為具體的文字

然後進一步清洗資料

這樣還有個問題,詞切分會自動小寫,與答案還存在一定的偏移,這裡介紹get_final_text方法來解決這一問題,比如:

pred_text = steve smith

orig_text = Steve Smith’s

這個方法通俗來講就是獲得orig_text(未經過詞切分)上正確的擷取片段。

然後將其新增到nbest中

同樣會存在沒有答案的情況

接下來會有一個total_scores,它的元素是start_logit和end_logit相加,注意,它們不是數值,是陣列,之後就計算total_scores的交叉熵損失作為概率。

剩下的部分跟文字分類差不多,這裡就此略過。


相關文章