Transformer的七十二變

PaperWeekly發表於2020-04-01

自 2017 年 Google 提出 Transformer 後,其在各項 NLP 任務中都取得了 SOTA 的表現。然而其自身的結構缺陷導致了兩個問題:

1)由於自注意力機制每次都要計算所有詞之間的注意力,其所需計算複雜度為輸入長度的平方;2)Transformer 需要事先設定輸入長度,這導致了其對於長程關係的捕捉有了一定限制,並且由於需要對輸入文件進行分割會導致語意上的碎片化。

近年來許多工作通過對 Transformer 結構的調整優化來緩解以上兩個問題。

本文分為兩部分,第一部分介紹和比較的三個模型(Star-Transformer 和 BP-Transformer)試圖在時間複雜度和空間複雜度上優化 Transformer。第二部分介紹和比較的兩個模型(Transformer-XL 和 Compressivetransformer)試圖解決上面提出的第二個問題。

對 Transformer 不瞭解的可先閱讀該部落格:

https://jalammar.github.io/illustrated-transformer/

一、更高效的Transformer

1. Star-Transformer

Transformer的七十二變

論文標題:Star-Transformer

論文來源:NAACL 2019

論文連結:https://arxiv.org/abs/1902.09113

程式碼連結:https://github.com/fastnlp/fastNLP

原始的 Transformer 在計算注意力的時候,序列中每個元素要和所有元素進行計算,也是這樣的計算方式導致了其複雜度為序列長度的平方。

同時 Transformer 這樣所有元素直接相互作用的計算方式沒能夠很好地使用我們所知道的一些語言序列上的特性,比如語言序列中相鄰的詞往往本身就會有較強的相關性。

對於這個問題,Star-Transformer 在注意力機制的計算上進行了優化,構建了一個星狀的結構,所有序列中直接相鄰的元素可以直接相互作用,而非直接相鄰的元素則通過中心元素實現間接得資訊傳遞。

具體結構比較如下圖所示,左邊為正常的 Transformer,右邊為 Star-Transformer。

Transformer的七十二變

下圖為 Star-Transformer 的引數更新演算法。在初始化階段,衛星節點 的初始值為相應的詞向量,而中心節點 的初始值為所有衛星節點詞向量的平均值。

演算法中引數更新分為兩步:第一步為衛星節點的更新,第二步為中心節點的更新。兩步的更新都是基於多頭注意力機制。

對於衛星節點,計算多頭注意力機制時只需考慮該節點狀態與直接相鄰節點,中心節點,該節點詞向量和本節點上一時刻狀態的資訊互動(如下圖中Transformer的七十二變 )。

因為中心節點擔負著所有衛星節點之間的資訊互動,因此中心節點在更新時須與自己上一時刻的資訊和所有衛星節點進行資訊互動。同時為了表示位置資訊,在衛星節點中還必須拼接上表示位置資訊的可學習的向量。

該模型在使用中,針對序列的下游任務使用衛星節點的輸出,而針對語言推理文字分類這種需要整個句子的任務則可以使用中心節點的輸出。

作者的實驗中表明,該非直接的聯絡方式同樣能夠學習到長程聯絡,同時在一些任務上的也取得了比 Transformer 更好的表現。

Transformer的七十二變

2. BP-TransformerTransformer的七十二變

論文標題:BP-Transformer: Modelling Long-Range Context via Binary Partitioning

論文來源:NAACL 2019

論文連結:https://arxiv.org/abs/1911.04070

程式碼連結:https://github.com/yzh119/BPT

Transformer的七十二變

Transformer的七十二變

Transformer的七十二變

構建完整個圖後,該模型可通過以下演算法更新引數:

Transformer的七十二變

其中 GSA (Graph Self-Attention) 為:

Transformer的七十二變

加入相對位置後,注意力的計算可修正為以下公式:

Transformer的七十二變

A(u) 為所有與 u 節點想連的節點,由上面公式可見 GSA 其實就是多頭注意力機制,只是相比原始 Transformer 計算一個節點與所有節點的注意力,這裡只計算節點與其相鄰節點的注意力,而因為在二叉樹中有跨層次的節點連線即有自節點元素和中間節點元素(片段)的連線,就實現在計算不同粒度下的注意力。

該模型在初始化時,葉子節點初始化為相應的詞向量,而片段節點則初始化為零。在針對像語言模型這種序列型的下游任務中,可使用葉子節點的輸出,而針對像文字分類等需要用的整個句子的則使用二叉樹根節點的輸出。

作者在多個任務中測試,結果表明相比原始的注意力計算方式,該模型在長文字任務中取得了更好的表現。

二、學習更長語義聯絡的Transformer

1. Transformer-XL

Transformer的七十二變

論文標題:Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

論文來源:ACL 2019

論文連結:https://arxiv.org/abs/1901.02860

程式碼連結:https://github.com/kimiyoung/transformer-xl

相比原始 Transformer,Transformer-XL 有以下兩個變化:1)引入迴圈機制,使得新模型能夠學習到更長的語義聯絡;2)拋棄絕對位置表示,採用相對位置表示。

1.1 迴圈機制

在原始 Transformer 中,每個序列的計算相互獨立,因此也導致了其只能夠學習到同個序列內的語義聯絡。而在 Transformer-XL 中,每個序列計算後的隱狀態會參與到下一個序列的計算當中,使得模型能夠學習到跨序列的語義聯絡。

如下圖所示,左邊為原始 Transformer,右邊為 Transformer-XL。

Transformer的七十二變

相比原始 Transformer,Transformer-XL 模型的計算當中加入綠色連線,使得當層的輸入取決於本序列和上一個序列前一層的輸出。具體計算公式如下:

Transformer的七十二變

其中 h 為隱藏層,n 為層數, τ 表示序列數,W 為模型引數,° 表示矩陣拼接。SG 意為 stop-gradient,即停止梯度計算,這樣雖然在計算中運用了前一個序列的計算結果,但是在反向傳播中並不對其進行梯度的更新。

式子一:將上一序列上一層隱狀態與本序列上一層隱狀態進行矩陣拼接,這也是 Transformer-XL 實現迴圈機制的關鍵。

式子二:計算注意力機制所需的 q,k,v。與原始 Transformer 不同的是 k,v 的計算是取決於由式一得到的隱狀態,而 q 則是隻含有本序列的資訊。在注意力的計算中,q 與 k,v 的相互作用讓模型實現了跨序列的語義學習。

式子三:常規的 Transformer 層計算。

Transformer-XL 通過引入跨層的迴圈機制,使得模型能夠學習到跨序列的語義資訊。這樣跨層的方式也使得其能夠學習到的語義長度受限於網路深度,具體依賴關係為 N*(L-1) 用大 O 表示可近似為 O(N*L),N 為網路深度,L 為序列長度。如下圖所示,序列長度為 4,網路深度為 3。

Transformer的七十二變

1.2 相對位置編碼

由於注意力機制忽視了位置資訊,因此在 Transformer 中需要加入位置編碼。原始 Transformer 採用了正弦/餘弦函式來編碼絕對位置資訊。然而在 Transformer-XL 中,若採用和 Transformer 一樣的絕對位置編碼,那麼不同序列間同個位置會得到同樣的編碼。

因此這種方法在 Transformer-XL 中行不通,為了解決這個問題 Transformer-XL 採用了相對位置編碼。

以下公式和分別為原始 Transformer 和 Transformer-XL 中注意力的計算公式。在其中 E 表示詞的 Embedding,而 U 表示絕對位置編碼。在中 R 為相對位置表示,該相對位置表示也是一個正弦函式表示。

相比,除了用相對位置表示 R 替代了絕對位置表示 U 後,還用兩個可學習引數 u 和 v 替代了中的 query 位置的對映,同時將原本對 key 的對映矩陣分成兩組矩陣和,分別生成基於內容的 key 向量和基於位置的 key 向量。

替換後中四項分別代表:(a) 基於內容的定址;(b) 基於內容的位置偏差;(c) 全部內容偏差;(d) 全域性位置偏差。

Transformer的七十二變

採用相對位置編碼後,Transformer-XL 具體的計算公式如下:

Transformer的七十二變

2. Compressive Transformer

Transformer的七十二變

論文標題:Compressive Transformers for Long-Range Sequence Modelling

論文來源:ICLR 2020

論文連結:https://arxiv.org/abs/1911.05507

為了增加 Transformer 可以學習到的語義長度,Compressiv Transformer 在原 Transformer 的結構上增加了一個記憶模組和一個壓縮記憶模組。

每一個序列計算後其隱狀態會被放入記憶模組中,然後記憶模組中的部分原有記憶會被壓縮然後放入壓縮記憶模組中,這時壓縮記憶模組中的部分記憶則會被拋棄掉。

如下圖所示,壓縮記憶模組和記憶模組維度皆為 6,而序列長度為 3。箭頭和f表示對記憶模組中的記憶進行壓縮並放入壓縮記憶模組中。

Transformer的七十二變

Compressive Transformer 具體的演算法細節如下,其中m表示記憶模組,cm 表示壓縮記憶模組,h 為隱狀態,d 為 Embedding 維度,為壓縮記憶模組長度,為記憶模組長度,c 為壓縮常數,l 為層數。

Transformer的七十二變

下圖為一個簡易示意圖,紅色表示計算注意力,藍色表示將計算過的序列存入記憶模組和壓縮記憶模組過程。

Transformer的七十二變

在論文中作者嘗試瞭如下幾個不同的壓縮函式:1)max/mean pooling;2)1Dconvolution;3)dialated convolutions;4)most-used。實驗表明在 WIKITEXT-103 資料集中 1D convolution 表現最好。

同時為了更好的學習壓縮函式的引數,模型訓練時使用了一個輔助的損失函式(因為若是依賴模型的損失函式,則梯度需要經過很長的時序才能傳到存貯的老的記憶,類似於 RNN 裡梯隊消失問題)。

該損失函式為注意力重建損失函式,旨在測量通過更新後的記憶計算的注意力和使用原本記憶計算的注意力之間的差距。通過最小化該差距來確保有效的壓縮資訊。

Transformer的七十二變

通過引入記憶模組後,Compressive Transformer 能夠捕捉的語義長度為 O(L*(+c) 其中為壓縮記憶模組長度,為記憶模組長度,c 為壓縮常數。

相比較 Transformer-XL 的 O(LN),Compressive Transformer 通過將計算後的序列儲存在記憶模組中有效的提高了模型捕捉長程語義的能力。

Reference:

BP-Transformer: Modelling Long-Range Context via Binary Partitioning.Zihao Ye, Qipeng Guo, Quan Gan, Xipeng Qiu, Zheng Zhang

Star-Transformer.Qipeng Guo, Xipeng Qiu, Pengfei Liu, Yunfan Shao, Xiangyang Xue, Zheng Zhang

COMPRESSIVE TRANSFORMERS FOR LONG-RANGE SEQUENCE MODELLING, Jack W. Rae Anna Potapenko Siddhant M. Jayakumar Chloe Hillier Timothy P. Lillicrap

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context, Zihang Dai∗12, Zhilin Yang∗12, Yiming Yang1, Jaime Carbonell1, Quoc V. Le2, Ruslan Salakhutdinov1

相關文章