背景
機構:Google Research 、U.C. Berkeley
作者:Nikita Kitaev、Łukasz Kaiser、Anselm Levskaya
論文地址:
https://www.aminer.cn/pub/5e5e189993d709897ce1ddbc
收錄會議:ICLR2020
論文程式碼:
https://github.com/google/trax/tree/master/trax/models/reformer
摘要
基於 Transformer 的各種巨型模型在各種自然語言處理任務中常常能夠取得最優結果,但這些模型的訓練成本往往過高,在針對長序列文字上尤甚。為此,本文提出兩種技術以改善基於 Transformer 的這類模型,名為 Reformer。第一,使用區域性敏感 hash,替換原始的點乘方式的 attention,從而將其空間複雜度從 O(L^2)降低到O(Llog L),其中L表示文字序列的長度。第二,使用逆殘差層代替標準的殘差,這使得訓練過程中只需儲存一次啟用值,而無需 N 次,其中 N 表示網路層數。最終的結果表明 Reformer 效能與 Transformer 相當,同時在長序列上具有更高的記憶體效率和更快的速度。
介紹
那訓練 Transformer 模型是否真需要很多資源且很低效?以現有的最大 Transformer 層為例,該 Transformer 層中引數量是 0.5B,這需要 2GB 的記憶體。(1M=1024KB,1KB=1024Byte。所以1GB=1024M=1024x1024KB=1024x1024x1024Byte=1073741824Byte。float 佔用 4 個 Byte。0.5B 即 5 億引數,需要的記憶體量為 5 億 *4 位元組=20 億位元組。這差不多是 1.86GB 即約為 2GB)對於由 64Ktokens 組成的序列,如果嵌入層的尺寸是 1024,batch size 是 8,那麼啟用值需要 64K * 1K * 8=0.5B 個浮點數來儲存,這又需要 2GB 的記憶體。如果每層的記憶體佔用只有上述提到的這些的話,那麼在單加速器上使用Transformer 處理 64K長度的序列也是輕而易舉。此外,如此前提下訓練 BERT 的整個語料庫也只需 17GB 的記憶體。然而,現實並非如此,真實環境下為何甚至不能在單臺機器上對這些模型進行微調呢?
這是因為上述僅僅考慮單層引數的記憶體佔用和輸入啟用值的記憶體消耗,而忽略了 Transformer 在記憶體佔用上的主要問題:
- 需要儲存啟用值用於反向傳播,那麼 N 層模型記憶體佔用是單層的 N 倍;
- 由於中間全連線層的深度 d_{ff} 通常遠大於注意力啟用層的深度 d_{model},而這需要佔用很大的記憶體;
- 長度為 L 的序列的 attention 的時間和空間複雜度是 O(L^2),那麼對於 64K tokens 的序列就會耗盡記憶體。
為此,本文提出 Reformer 模型以解決上述問題,具體採用如下方案:
- 可逆層(Reversible layer),在整個模型中只使用單個副本,可以消除層數因子 N。
- 前饋層(feed-forward layer)分開啟用和分塊處理,從而消除 d_{ff} 因子的影響,降低前饋層的記憶體佔用。
- 採用基於區域性敏感雜湊(locality-sensitive hashing,LSH)的近似注意力計算,讓注意力層的 O(L^2) 因子變為 O(L log L) ,這使得在長序列上的處理成為可能。
區域性敏感雜湊 ATTENTION
點乘 attention:
標準的 Transformer 使用點乘的 attention,queries 和 keys 的維度都是 d_k,values 的維度是 d_v。query 先與 key 做點乘,再除以根號 d_k,再輸入到 softmax 中得到 value 的權重,最後權重再與 value 相乘,得到最終的結果。在實際操作過程中是以矩陣方式進行批次操作,queries 組成矩陣 Q,keys 組成矩陣 K,values 組成矩陣 V,上述流程概況如下:
多頭 attention:
上述的 attention 操作並行地進行 h 次,再輸出維度為 d_v 的輸出結果。再將這些結果拼接,再做一次投射操作得到最終的結果。即所謂的多頭 attention。
高效記憶體 attention:
先來算下上述 attention 機制消耗的記憶體。假設 Q,K,V 的尺寸為 [batch_size,length,d_model]。QK^T 的尺寸為 [batch_size,length,length]。當 length=64k,即使 batch_size=1,那麼 64k*64k 大小的矩陣,如果用 32 位浮點數來儲存的話,需要 16GB 記憶體。鑑於此,在長序列上使用 Transformer 顯得不切實際。但是需要注意的是,QK^T 矩陣可以不必全部放在記憶體中,可以對每個 query 分別計算 attention。反向傳播計算梯度時再重新計算一次。這種方式計算 attention 雖然低效,但是所佔用的記憶體與 length 成正比。這種方法在本文這裡作為一種全 attention 的 baseline。
Q,K,V 從何處來?
上述討論了 Q、K、V,但是一般我們只會得到大小為 [batch_size,length,d_model] 的啟用值 A,這些值是 token 的嵌入所組成的句向量。那麼為了從 A 中得到Q、K、V,Transformer 使用了 3 個不同的線性層(引數不同)將 A 投射為 Q、K、V。對於使用區域性敏感雜湊 attention 的模型,我們希望 queries 和 keys(即 Q 和 K)相同。只需要 A 投射到 Q 和 A 投射到 K 時採用相同線性變換引數即可,而 A 投射到 V 時採用不同引數。這種方式成為共享 QK-Transformer。實驗表明共享 QK 並不會影響 Transformer 的效能,即使新增一項 d_k 的歸一化項。
Hashing attention:
在 LSH attention 中,假設 Q、K、V 的尺寸為 [batch_size,length,d_model],同時仍然使用此前介紹的多頭 attention 機制。那麼 QK^T 的尺寸為 [batch_size,length,length]。由於 softmax(QK^T) 的計算結果主要取決於值最大的部分,對於每個 query 只需關注 K 中與 query 最接近的點。當 K 的長度是 64k,那麼對個每個 query,本文僅僅考慮其最近的的 32 或 64 個 keys。如此會更加高效,那麼如何找尋最近的那些 keys 呢?
區域性敏感雜湊(LSH):
在高緯空間中找尋最近鄰可以使用區域性敏感雜湊(LSH)。將每個向量 x 透過 hash 函式h(x) 進行對映,如果近處的向量獲得相同的 hash,且具有高機率,而遠處的向量沒有,那麼這樣的 hash 稱為位置敏感型 hash。在此處例子中,我們實際上只要求近鄰的向量以高機率具有相同的 hash 值,並且 hash 桶也以高機率具有相同的大小。
具體是使用如 Figure 1 所示的隨機投射方法:
圖片來源:https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0
這裡有兩個點,它們投影到一個單位圓上,並以不同的角度隨機旋轉 3 次。可以觀察到,它們不太可能共享同一個 hash 桶。在後續例子中,可以看到兩個非常接近的點在3 次隨機旋轉後會位於相同的 hash 桶:
https://miro.medium.com/max/1052/1*aArg6a26KqbIlEkT43fxlw.gifAngular LSH 最近鄰搜尋的的一個簡化動畫:兩個點很接近的情況。
圖片來源:https://towardsdatascience.com/illustrating-the-reformer-393575ac6ba0
LSH attention:
其中 P_i 表示 query 在位置 i 所需要 attend 的集合,z 表示配分函式(partition function)比如 softmax 中的歸一化項。為了書寫清楚,這裡省略了縮放項根號 d_k。
對於批次操作,當遮蔽掉不在 P_i 中的元素,此時常規 attention 定義如下:
即對於不能 attend 到的位置,m(j, P_i) 為正無窮,那麼 q_i* k_j 減去正無窮再去 exp 操作,其結果為 0。這樣就不需要對於每個位置i都有單獨的 P_i。
在 LSH attention 中,query 中位置 i 所能夠 attend 的限制集合 P_i 被限制到一個 hash 桶中。Figure 2(a-b)展示的是全 attention 和 hash attention 的對比。
圖 b:計算 query 和 key 所歸屬的 hash 桶。再按照桶進行排序,同一個桶又按照原本的位置進行排序得到圖 b。可以看到,同一個桶,可以出現多個 query 但 keys 很少的情況,例如圖中藍色的桶 query 有 3 個,都 attend 到同一個 key 中。由於相似的 item 很有可能落在同一個桶裡,所以只在每個桶內部進行 attention 就可以近似全 attention。
圖 c:為了緩解桶中 q 和 k 不均衡問題,本文透過令 $k_{j}=\frac{q_{j}}{\left\|q_{j}\right\|}$ 使得 h(k_j)=h(q_j),即使用了 share-QK attention。然後先按照桶序號對 queries 排序,每個桶中,仍按照原本的 position 位置大小排序。得到圖 c。對比 b 圖和 c 圖可以看出,縱軸的 k 已經變成了 q。這時就能保證對角線都是 attend 到的而且 q 和 k 在桶中的個數一樣(因為 Q=K)。排序後的 attention 矩陣,相同桶的值會在對角線附近聚集。注意到圖中對角線的點為空心,這是因為雖然在正常情況下,q 會 attend to 本身位置的 value,但是在 share-QK 的實現下,如果 attend to 本身,會導致其值特別大,其他的值特別小,經過 softmax 之後,其他都是 0,就自己本身是 1。所以為了避免這種情況,q 不會去 attend 自身位置的值,除非只有自己本身可以 attend。
圖 d:即使 Q=K,還是會出現一個問題:有的桶中個數多,有的桶中個數少。比如一個極端情況,2 個桶,其中一個桶佔據了所有的 keys,另一個桶為空,那麼 LSH attention 就沒有起作用。於是在圖 c 的基礎上,增加了 chunk 的操作。對輸入進行排序之後(即圖 c 中先桶排序,同個桶內按照 token 的 position 排序)得到新的序列順序 s_i,比如圖中原來的序列順序是 [q_1,q_2,q_3,q_4,q_5,q_6],新的序列順序是[q_1,q_2,q_4,q_3,q_6,q_5] 。每個 chunk 內 query 的上限個數為 $m=\frac{2 l}{n_{\text {buckets}}}$, (l 為輸入 query 的長度) ,每個桶平均大小為 $m=\frac{l}{n_{\text {buckets}}}$,這裡假設桶中數量增加到均值兩倍的機率足夠低。對於桶中的每個 query,都可以 attend to 自己以及前一個桶中相同 hash 值的 key。
小結下,LSH attention 做了以下兩個事情:
第一,找到 Q、K 矩陣的 LSH hashes。
第二,在同一個 hash 桶內計算 k 和 q 向量的標準 attention。
更具體來說可分為以下 5 個步驟:
第一,令輸入序列 queries=keys
第二,做 LSH bucketing,即進行 hash 計算,得到每個 query 和 key 所歸屬的桶(不同顏色表示不同的桶)。
第三,根據桶編號對 query 進行排序,同個桶中,按照 query 原本的位置進行排序。
第四,對於排序後的新序列,進行 chunk 拆分
第五,對於每個 query 只 attend 自己以及自己之前的 chunk,對於這些候選集中相同桶的 key 進行 attend。
多輪 LSH attention:
LSH 有近似性,即不能保證相似的輸入能在同一個桶中。為了減輕這個問題,採用了 multi-round LSH attention。即重複上述過程多次,以使類似的 item 以儘可能高的機率落入相同的桶中,儘量避免相似 item 落入不同桶。更多的細節參考附件 A。
可逆層
如上所述,attention 的複雜度可以被減少為與序列長度成線性正比,但是,引數量佔的複雜度依舊很高,如何進一步減少呢?這裡就開始嘗試解決前文介紹部分所提到的第二和第三個問題,即大量的 encoder 和 decoder 層、全連線層 FFN 的深度問題。
Reversible residual Network (RevNet)
RevNet 的思想是每一層的 activations 可以根據下一層的 activations 推導獲得,從而不需要在記憶體中儲存 activations。在原本的 residual layer 中,由公式 y=x+F(x) 輸出得到 activations。其中 F 是 residual 函式。在 RevNet 中,先將輸入x分為兩個部分 x_1 和 x_2,然後透過不同 residual functions:F() 和 G() 得到輸出 y_1 和 y_2:
那麼如何在 Transformer 中引入 RevNet?將 attention layer 和 FFN layer 透過 ResNet 連線,從而減少記憶體的消耗。具體是令F函式為 attention 層,G 函式作為 FFN 層。需要注意的一點是 layer normalization 是包含在 residual blocks 中的。
Chunking
上述消除了 n_l 項的影響,深層的網路仍然佔有大量記憶體。在 FFN 中中間隱藏層的緯度通常非常大,比如 d_{ff}=4k 或者更大。由於 FFN 的計算與序列中的位置完全無關,因此計算可以被分割成 c 個塊,以降低記憶體的使用。雖然該操作其實可並行處理,但是每次只計算一個 chunk,透過時間換取記憶體空間。
實驗結果
對影像生成任務 imagenet64(序列長度為 12K)和文字任務 enwik8-64K(即序列長度為64K)進行了實驗,評價了可逆層、共享 query-key、LSH attention 對記憶體、精度和速度的影響。
可逆層和共享 query-key 的影響:
Figure 3 中的右部分驗證的是可逆層的影響。實驗中對比的可逆層和常規 Transformer 引數量相同,且學習曲線看起來也幾乎相同。這些結果表明,可逆 Transformer 在節省記憶體的同時並不會犧牲精度。
LSH attention 的影響:
如 Figure 4 所示,可以看出隨著 hash 數的增多精度也提升了。
更大的 Reformer 模型:
Figure 5 展示了不同層數的 Reformer 在 envik8 和 imagenet64 上的表現。下圖(左)是 Big Reformer 隨層數變化指標結果,20 層依然無壓力。而下圖(右)是普通 attention 和 LSH attention 在不同序列長度的速度比較,當序列很長的時候,LSH 具有顯著的優勢。
總結
Reformer 將 Transformer 的建模能力與能夠在長序列上高效執行的體系結構相結合,使其即使處理大模型時,也可以使用較小的記憶體。這將有助於大型、海量引數化的 Transformer 模型變得更廣泛可用。此外,處理長序列的能力為 Reformer 在許多生成任務上的使用開闢了道路。除了生成非常長的連貫文字外,Reformer 可以把 Transformer 模型的能力應用到其他領域,如時間序列預測、音樂、影像等。
作者:劉傑鵬(微訊號:onepieceand)
畢業院校:華中科技大學
研究方向:機器閱讀理解、文字生成等。