背景
BERT、XLNet、RoBERTa等基於Transformer[1]的預訓練模型推出後,自然語言理解任務都獲得了大幅提升。問答任務(Question Answering,QA)[2]也同樣取得了很大的進步。
用BERT類模型來做問答或閱讀理解任務,通常需要將問題和問題相關文件拼接一起作為輸入文字,然後用自注意力機制對輸入文字進行多層互動編碼,之後用線性分類器判別文件中可能的答案序列。如下圖:
雖然這種片段拼接的輸入方式可以讓自注意力機制對全部的token進行互動,得到的文件表示是問題相關的(反之亦然),但相關文件往往很長,token數量一般可達問題文字的10~20倍[3],這樣就造成了大量的計算。
在實際場景下,考慮到裝置的運算速度和記憶體大小,往往會對模型進行壓縮,比如透過蒸餾(distillation)小模型、剪枝(pruning)、量化(quantization)和低軼近似/權重共享等方法。
但模型壓縮還是會帶來一定的精度損失。因此我們思考,是不是可以參考雙塔模型的結構,提前進行一些計算,從而提升模型的推理速度?
如果這種思路可行,會有幾個很大的優勢:
它不需要大幅修改原來的模型架構 也不需要重新預訓練,可以繼續使用標準Transformer初始化+目標資料集fine-tune的精調方式 還可以疊加模型壓縮技術
經過不斷地嘗試,我們提出了《Deformer:Decomposing Pre-trained Transformers for Faster Question Answering》[4],在小幅修改模型架構且不更換預訓練模型的情況下提升推理速度。下面將為大家介紹我們的思考歷程。
論文連結:
https://awk.ai/assets/deformer.pdf
程式碼連結:
https://github.com/StonyBrookNLP/deformer
模型結構
在開篇的介紹中,我們指出了QA任務的計算瓶頸主要在於自注意力機制需要互動編碼的token太多了。因此我們猜想,是否能讓文件和問題在編碼階段儘可能地獨立?
這樣的話,就可以提前將最難計算的文件編碼算好,只需要實時編碼較短的問題文字,從而加速整個QA過程。
部分研究表明,Transformer 的低層(lower layers)編碼主要關注一些區域性的語言表層特徵(詞形、語法等等),到高層(upper layers)才開始逐漸編碼與下游任務相關的全域性語義資訊。因此我們猜想,至少在模型的某些部分,“文件編碼能夠不依賴於問題”的假設是成立的。 具體來說可以在 Transformer 開始的低層分別對問題和文件各自編碼,然後再在高層部分拼接問題和文件的表徵進行互動編碼,如圖所示:
為了驗證上述猜想,我們設計了一個實驗,測量文件在和不同問題互動時編碼的變化程度。下圖為各層輸出的文件向量和它們中心點cosine距離的方差:
可以看到,對於BERT-Based的QA模型,如果編碼的文件不變而問題變化,模型的低層表徵往往變化不大。這意味著並非所有Transformer編碼層都需要對整個輸入文字的全部token序列進行自注意力互動。
因此,我們提出Transformer模型的一種變形計算方式(稱作 DeFormer):在前層對文件編碼離線計算得到第 層表徵,問題的第層表徵透過實時計算,然後拼接問題和文件的表徵輸入到後面到層。下面這幅圖示意了DeFormer的計算過程:
值得一提的是,這種方式在有些QA任務(比如SQuAD)上有較大的精度損失,所以我們新增了兩個蒸餾損失項,目的是最小化Deformer的高層表徵和分類層logits與原始BERT模型的差異,這樣能控制精度損失在1個點左右。
實驗
這裡簡要描述下四組關鍵的實驗結果:
(1)在三個QA任務上,BERT和XLNet採用DeFormer分解後,取得了2.7-3.5倍的加速,節省記憶體65.8-72.0%,效果損失只有0.6-1.8%。BERT-base()在SQuAD上,設定能加快推理3.2倍,節省記憶體70%。
(2)實測了原模型和DeFormer在三種不同硬體上的推理延遲。DeFormer均達到3倍以上的加速。
(3)消融實驗證明,新增的兩個蒸餾損失項能起到彌補精度損失的效果。
(4)測試DeFormer分解的層數(對應折線圖橫軸)對推理加速比和效能損失的影響。這個實驗在SQuAD上進行,且沒有使用蒸餾trick。
總結
這篇文章提主要提出了一種變形的計算方式DeFormer,使問題和文件編碼在低層獨立編碼再在高層互動,從而使得可以離線計算文件編碼來加速QA推理和節省記憶體。
創新之處在於它對原始模型並沒有太大修改。部署簡單,而效果顯著。 實驗結果表明基於BERT和XLNet的Deformer均能取得很好的表現。筆者推測對其他的Transformer模型應該也同樣有效,並且其他模型壓縮方法和技術應該也可以疊加使用到DeFormer上來進一步加速模型推理。