1. Motivation
在Transformer-XL中,由於設計了segments,如果仍採用transformer模型中的絕對位置編碼的話,將不能區分處不同segments內同樣相對位置的詞的先後順序。
比如對於$segment_i$的第k個token,和$segment_j$的第k個token的絕對位置編碼是完全相同的。
鑑於這樣的問題,transformer-XL中採用了相對位置編碼。
2. Relative Positional Encodings
paper中,由對絕對位置編碼變換推匯出新的相對位置編碼方式。
vanilla Transformer中的絕對位置編碼
它對每個index的token都通過sin/cos變換,為其唯一指定了一個位置編碼。該位置編碼將與input的embedding求sum之後作為transformer的input。
那麼如果將該位置編碼應用在transformer-xl會怎樣呢?
其中$\tau$表示第$\tau$個segment, 是當前segment的序列$s_{\tau}$的word embedding sequence, $L$是序列長,$d$是每個word embedding的維度。$U_{1:L}$表示該segment中每個token的絕對位置編碼組成的序列。
可以看到對於$h_{\tau + 1}$和$h_{\tau}$,其在位置編碼表示是完全相同的,都是$U_{1:L}$,這樣就會造成motivation中所述的無法區分在不同segments中相對位置相同的tokens.
3. Transformer-XL中的相對位置編碼
transformer-xl中沒有采用vanilla transformer中的將位置編碼靜態地與embedding結合的方式;而是沿用了shaw et al.2018的相對位置編碼中通過將位置資訊注入到求Attention score的過程中,即將相對位置資訊編碼入hidden state中。
為什麼要這麼做呢?paper中給出的解釋是:
1) 位置編碼在概念上講,是為模型提供了時間線索或者說是關於如何收集資訊的"bias"。出於同樣的目的,除了可以在初始的embedding中加入這樣的統計上的bias, 也可以在計算每層的Attention score時加入同樣的資訊。
2) 以相對而非絕對的方式定義時間偏差更為直觀和通用。比如對於一個query vector $q_{\tau,i}$ 與 key vectors $k_{\tau, \leq i}$做attention時,這個query 並不需要知道每一個key vector在序列中的絕對的位置來決定segment的時序。它只需要知道每一對$k_{\tau,j}$ 和其本身$q_{\tau,i}$的相對距離(比如,i - j)就足夠。
因此,在實際中可以建立一個相對位置編碼的encodings矩陣 $R \in \mathbb{R} ^ {L_{max} \times d}$,其中第i行 $R_i$表示兩個pos(比如位置pos_q, pos_k)之間的相對距離為i. (可以參考我在參考連結3中的介紹,以下圖示便是一個簡單的說明例子.
但是圖示中的i表示query的位置pos, 與$R_i$ 中的i不同。如果以該圖示為例,當pos_q = i, pos_k = i - 4時, 相對位置為 0, 二者的相對位置編碼是 $R_0$。
--------------------------------------------------------------------------------------------------
Transformer-XL的相對位置編碼方式是對Shaw et al.,2018 和 Huang et al.2018提出模型的改進。它由採用絕對編碼計算Attention score的表示式出發,進行了改進3項改變。
若採用絕對位置編碼,hidden state的表示式為:
,
那麼對應的query,key的attention score表示式為:
(應用乘法分配率, query的embedding 分別與 key的embedding, positional encoding相乘相加;之後 query的positional encoding分別與 key的embedding, positional encoding相乘相加)
(其中i是query的位置index,j是key的位置index) (WE, WU是對embedding進行linear projection的表示,細節內容可以參看attention is all you need 中對multi-head attention的介紹)
,
Transformer-XL 對上式進行了改進:
改進1) $Uj \rightarrow R_{i - j}$.
首先將 $A_{i, j} ^ {abs}$ 中的key vector的絕對位置編碼 $U_j$ 替換為了相對位置編碼 $R_{i - j}$ 其中 $R$是一個沒有需要學習的引數的sinusoid encoding matrix,如同Vaswani et al., 2017提出的一樣。
該改進既可以避免不同segments之間由於tokens在各自segment的index相同而產生的時序衝突的問題。
改進2) $(c) : U_i^{T} W_q ^ {T} \rightarrow {\color{red} u} \in \mathbb{R}^d$;$(d) : U_i^{T} W_q ^ {T} \rightarrow {\color{red} v} \in \mathbb{R}^d$
在改進1中將key的絕對位置編碼轉換為相對位置編碼,在改進2中則對query的絕對位置編碼進行了替換。因為無論query在序列中的絕對位置如何,其相對於自身的相對位置都是一樣的。這說明attention bias的計算與query在序列中的絕對位置無關,應當保持不變. 所以這裡將$A_{i, j} ^ {abs}$ 中的c,d項中的$U_i^{T} W_q ^ {T}$分別用一個可學習引數$u \in \mathbb{R}^d$,$v \in \mathbb{R}^d$替換。
改進3) $W_{k} \rightarrow W_{k, E}$, $W_{k, R}$
在vanilla transformer模型中,對query, key分別進行線性對映時,query 對應$W_q$矩陣,key對應$W_k$矩陣,由於input 是 embedding 與 positional encoding的相加,也就相當於
$query_{embedding} W_q + query_{pos encoding} W_q$得到query的線性對映後的表徵;
$key_{embedding} W_q + key_{pos encoding} W_q$ 得到key的線性對映後的表徵。
可以看出,在vanilla transformer中對於embedding和positional encoding都是採用的同樣的線性變換。
在改進3中,則將key的embedding和positional encoding 分別採用了不同的線性變換。其中$W_{k,E}$對應於key的embedding線性對映矩陣,$W_{k,R}$對應與key的positional encoding的線性對映矩陣。
在這樣的引數化定義後,每一項都有了一個直觀上的表徵含義,(a)表示基於內容content的表徵,(b)表示基於content的位置偏置,(c)表示全域性的content的偏置,(d)表示全域性的位置偏置。
與shaw的RPR的對比
shaw的RPR可以參考我在參考連結3中的介紹。這裡給出論文中的表示式:其中$a_{i,j}$是query i, key j的相對位置編碼矩陣$A$中的對應編碼。
attention score: (在key的表徵中加入相對位置資訊)
softmax計算權值係數:
attention score * (value + 的output:(在value的表徵中加入相對位置資訊)
1) 對於$e_{ij}$可以用乘法分配率拆解來看,那麼其相當於transforerm-xl中的(a)(b)兩項。也就是在shaw的模型中未考慮加入(c)(d)項的全域性內容偏置和全域性位置偏置。
2) 還是拆解$e_{ij}$來看,涉及到一項為$x_iW^Q(a_{ij}^K)^T$,是直接用 query的線性對映後的表徵 與 相對位置編碼相乘;而在transformer-xl中,則是與query的線性對映後的表徵 與 相對位置編碼也進行線性對映後的表徵 相乘。
優勢:
paper中指出,shaw et al用單一的相對位置編碼矩陣 與 transformer-xl中的$W_kR$相比,丟失掉了在原始的 sinusoid positional encoding (Vaswani et al., 2017)中的歸納偏置。而XL中的這種表徵方式則可以更好地利用sinusoid 的inductive bias。
----------------------------為什麼XL中的這種表徵方式則可以更好地利用sinusoid 的inductive bias?--------------------------------------------------------------------
有幾個問題:原始的 sinusoid positional encoding (Vaswani et al., 2017)中的歸納偏置是什麼呢?為什麼shaw et al 把它丟失了呢?為什麼transformer-xl可以適用呢?
這裡需要搞清楚:
1. 為什麼在vanilla transformer中使用sinusoid?
2. shaw et al.2018中的相對位置編碼Tensor是什麼?
3. transformer-xl的相對位置編碼矩陣是什麼?
對於1,sinusoid函式具有並不受限於序列長度仍可以較好表示位置資訊的特點。
We chose the sinusoidal version because it may allow the model to extrapolate to sequence lengths longer than the ones encountered during training. ~Attention is all you need.
為什麼不用學得引數而採用sinusoid函式呢?sinusoidal函式並不受限於序列長度,其可以在遇到訓練集中未出現過的序列長度時仍能很好的“extrapolate.” (外推),這體現了其具有一些inductive bias。
對於2,shaw et al.2018中的相對位置編碼Tensor是兩個需要引數學習的tensor.
相對位置編碼矩陣是設定長度為 2K + 1的(K是視窗大小) ,維度為$d_a$的2個tensor(分別對應與key的RPR和value的RPR),其第i行表示相對距離為i的query,key(或是query, value)的相對位置編碼。這兩個tensor的引數都是需要訓練學習的。那麼顯然其是受限於最大長度的。在RPR中規定了截斷的視窗大小,在遇到超出視窗大小的情況時,由於直接被截斷而可能丟失資訊。
對於3,transformer-xl的相對位置編碼矩陣是一個sinusoid矩陣,不需要引數學習。
在transformer-xl中雖然也是引入了相對位置編碼矩陣,但是這個矩陣不同於shaw et al.2018。該矩陣$R_{i,j}$是一個sinusoid encoding 的矩陣(sinusoid 是借鑑的vanilla transformer中的),不涉及引數的學習。
具體實現可以參看程式碼,這裡展示了pytorch版本的位置編碼的程式碼:
1 class PositionalEmbedding(nn.Module): 2 def __init__(self, demb): 3 super(PositionalEmbedding, self).__init__() 4 5 self.demb = demb 6 7 inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 8 self.register_buffer('inv_freq', inv_freq) 9 10 def forward(self, pos_seq, bsz=None): 11 sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 12 pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 13 14 if bsz is not None: 15 return pos_emb[:,None,:].expand(-1, bsz, -1) 16 else: 17 return pos_emb[:,None,:]
其中$demb$是embedding的維度。
sinusoid的shape:[batch_size, seq_length × (d_emb / 2)]
sin,cos concat之後,pos_emb的shape:[batch_size, seq_length × d_emb]
pos_emb[:,None,:]之後的shape:[batch_size, 1, seq_length × d_emb]
那麼綜合起來看,transformer-xl的模型的hidden states表示式為:
4. 高效計算方法
在該表示式中,在計算$W_{k,R}R_{i-j}$時,需要對每一對(i,j)進行計算,時間複雜度是$O(n^2)$。paper中提出了高效的計算方法,使其降為$O(n).$
核心演算法:發現(b)項組成的矩陣的行列之間的關係,構建一個矩陣,將其按行左移,恰好是(b)項矩陣$B$,而所構建的矩陣只需要$O(n)$時間。
由於相對距離(i-j)的變化範圍是[0, M + L - 1] (其中M是memory的長度,L是當前segment的長度)
那麼令:
那麼將(b)項應用與所有的(i,j)可得一個$L \times (M + L)$的矩陣 $B$: (其中q是對E經過$W_q$對映變換後的表示)
看這些帶紅線的部分,是不是隻有q的下標不一樣!
如果我們定義$\widetilde{B}$:
對比$B$與$\widetilde{B}$發現,將$\widetilde{B}$的第i行左移 $L - 1 - i$個單位即為$B$。而$\widetilde{B}$的計算僅涉及到兩個矩陣的相乘,因此$B$的計算也僅需要求$qQ^T$之後按行左移即可得到,時間複雜度降為$O(n)$!
同理,可以求(d)項的矩陣D。
這樣將B,D原本需要$O(n^2)$的複雜度,降為了$O(n)$.
5. 總結
Transformer-XL針對其需要對segment中相對位置的token加入位置資訊的特點,將vanilla transformer中的絕對位置編碼方式,改進為相對位置編碼。改進中涉及到位置編碼矩陣的替換、query全域性向量替換、以及為key的相對位置編碼和embedding分別採用了不同的線性對映矩陣W。
transformer-xl與shaw et al.2018的相對編碼方式亦有區別。1. shaw et al.2018的相對編碼矩陣是一個需要學習引數的tensor,受限於相對距離的視窗長度設定;而transformer-xl的相對編碼矩陣是一個無需引數學習的使用sinusoid表示的矩陣,可以更好的generalize到訓練集中未出現長度的長序列中;2. 相比與shaw et al.2018,transformer-xl的attention score中引入了基於content的bias,和基於位置的bias。
另外在計算優化上,transformer-xl提出了一種高效計算(b)(d)矩陣運算的方法。通過構造可以在$O(n)$時間內計算的新矩陣,並將其項左移構建出目標矩陣B,D的計算方式,將時間複雜度由$O(n^2)$降為$O(n)$。
參考:
1. Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context: https://arxiv.org/pdf/1901.02860.pdf
2. Self-Attention with Relative Position Representations (shaw et al.2018): https://arxiv.org/pdf/1803.02155.pdf
3. [NLP] 相對位置編碼(一) Relative Position Representatitons (RPR) - Transformer https://www.cnblogs.com/shiyublog/p/11185625.html
[支付寶] 感謝您的捐贈!
That's been one of my mantras - focus and simplicity. Simple can be harder than complex: you have to work hard to get your thinking clean to make it simple. But it's worth it in the end beacuse once you get there, you can move mountains. ~ Steve Jobs