EMNLP’19-Mask-Predict: Parallel Decoding of Conditional Masked Language Models

soccqy發表於2020-10-12

Intorduction

大多數機器翻譯系統使用順序譯瑪的策略,其中單詞是一個一個預測的。本文展示了一個並行譯碼的模型,該模型在恆定譯碼迭代次數下得到翻譯結果。本文提出的條件掩碼語言模型(CMLMS
解碼器的輸入是一個完全被masked的句子,並行的預測所有的單詞,並在恆定數量的遮蔽-預測迴圈之後結束。這種整體策略使模型可以在豐富的雙向上下文中反覆重新考慮單詞的選擇,並且正如我們將要展示的那樣,它僅在幾個週期內就可以產生高質量的翻譯。 M a s k − p r e d i c t Mask-predict Maskpredict反覆掩蓋並重新預測模型對當前轉換最不滿意的單詞子集。

Conditional Masked Language Models

  • Y Y Y:目標語句
  • X X X:源語句
  • Y o b s , Y m a s k Y_{obs},Y_{mask} YobsYmask:將目標語句劃分為兩類。

C M L M CMLM CMLM根據 X 與 Y o b s X與Y_{obs} XYobs預測 Y m a s k Y_{mask} Ymask

Architecture

E n c o d e r Encoder Encoder:基於自注意機制對原文字進行編碼。
D e c o d e r Decoder Decoder:目標語言的譯碼器,具有面向編碼器輸出的自注意機制,以及另外一組面向目標文字的自注意力機制。作者通過刪除自注意mask機制來改進標準解碼器。

Training Objective

先對目標語句隨機的選擇 Y m a s k Y_{mask} Ymask,被mask的token的數量遵循正態分佈。之後選中的token被一個特殊的 M A S K    MASK\; MASKtoken來代替。作者利用交叉熵來優化模型。並且,儘管譯碼器的輸出是整個目標語句,但只對 Y m a s k Y_{mask} Ymask執行交叉熵損失函式。

Predicting Target Sequence Length

在非自迴歸機器翻譯中,通常將整個編碼器的輸出作為一個目標語句長度預測模型的輸入來得到目標語句的長度。本文中作者,直接將 L E N G T H LENGTH LENGTH作為一個輸入token輸入編碼器,利用編碼器來預測目標語句的長度,即編碼器的一個輸出為目標序列的長度 N N N,並利用交叉熵損失來訓練。

Decoding with Mask-Predict

M a s k − P r e d i c t Mask-Predict MaskPredict:首先,先選擇若干token進行mask,然後用decoder去預測它們,將輸出中,預測概率值小的token再次mask,並輸入decoder中再次預測目標序列。

Formal Description

  • 根據Encoder預測出來的目標序列長度 N N N,作者定義了兩個變數 ( y 1 , . . . , y N ) (y_{1},...,y_{N}) (y1,...,yN)以及 ( p 1 , . . . , p N ) (p_{1},...,p_{N}) (p1,...,pN)。這個過程將進行T個迴圈(T可以是一個常數或者序列長度N的函式),並且在每次迭代過程中,都會執行Mask操作,然後是預測目標序列。

  • M a s k \mathbf{Mask} Mask:在第一次迭代時,作者將N個token全部mask,之後的迭代過程中,作者只mask掉預測概率值最低的n(n是迭代次數的函式即 n = N ⋅ T − t T n=N\cdot{\frac {T-t}{T}} n=NTTt)個token。即:
    在這裡插入圖片描述

  • P r e d i c t \mathbf{Predict} Predict:在得到 Y m a s k ( t ) Y_{mask}^{\left(t\right)} Ymask(t)後,CMLM將根據原文字 X X X Y o b s ( t ) Y_{obs}^{\left(t\right)} Yobs(t)來預測被mask掉的token。

舉例介紹:
在這裡插入圖片描述

  • 首先是序列長度預測:作者選擇 l l l個序列長度,平行計算。
  • 給定如上的句子,首先將全部mask的序列輸入譯碼器中,如下:
    在這裡插入圖片描述
    接下來作者選擇這12個token中預測概率值最小的八個,將他們mask掉,在第二次迭代中重新預測。第二次迭代的輸出如上,再將預測概率值最低的四個token執行mask,並再次預測。最終經過繼續的迭代的到最終的輸出,作者對比 l l l個輸出,將概率值對大的序列作為最終輸出:
    在這裡插入圖片描述

相關文章