結合RNN與Transformer雙重優點,深度解析大語言模型RWKV

华为云开发者联盟發表於2024-07-05

本文分享自華為雲社群《【雲駐共創】昇思MindSpore技術公開課 RWKV 模型架構深度解析》,作者:Freedom123。

一、前言

Transformer模型作為一種革命性的神經網路架構,於2017年由Vaswani等人 提出,並在諸多工中取得了顯著的成功。Transformer的核心思想是自注意力機制,透過全域性建模和平行計算,極大地提高了模型對長距離依賴關係的建模能力。但是Transformer在處理長序列時面臨記憶體和計算複雜度的問題,因為其複雜度與序列長度呈二次關係一直未業內人員所詬病。今天我們學習的RWKV,它作為對Transformers模型的替代,正在引起越來越多的開發人員的關注。RWKV模型以簡單、高效、可解釋性強等特點,成為自然語言處理領域的新寵。下面讓我們一起來學習RWKV模型。

二、RWKV簡介

RWKV(Receptance Weighted Key Value)是一個結合了RNN與Transformer雙重優點的模型架構,由香港大學物理系畢業的彭博首次提出。其名稱源於其 Time-mix 和 Channel-mix 層中使用的四個主要模型元素:R(Receptance):用於接收以往資訊;W(Weight):是位置權重衰減向量,是可訓練的模型引數; K(Key):是類似於傳統注意力中 K 的向量; V(Value):是類似於傳統注意力中 V 的向量。

RWKV模型作為一種革新性的大型語言模型,結合了RNN的線性複雜度和Transformer的並行處理優勢,引入了Token shift和Channel Mix機制來最佳化位置編碼和多頭注意力機制,解決了傳統Transformer模型在處理長序列時的計算複雜度問題。RWKV在多語言處理、小說寫作、長期記憶保持等方面表現出色,可以主要應用於自然語言處理任務,例如文字分類、命名實體識別、情感分析等。

cke_138.png

三、 RWKV模型的演進

RWKV模型之所以發展到今天的結構經歷了五個階段,從RNN結構到LSTM結構,到GRU結構,到GNMT結構,到Transformers結構,最後到RMKV結構,下面我們一一來學習每種模型結構,務必做到對模型結構都有一個清晰的認識。

cke_139.png

1.RNN結構

RNN(Recurrent Neural Network)是迴圈神經網路的縮寫,是一種深度學習模型,特別適用於處理序列資料。RNN具有記憶功能,可以在處理序列資料時保留之前的資訊,並將其應用於當前的計算中。RNN的特點在於其具有迴圈連線的結構,使得資訊可以在網路中傳遞並被持續更新。RNN由一個個時間步組成,每個時間步的輸入不僅包括當前時刻的輸入資料,還包括上一個時間步的隱藏狀態,這樣就可以在處理序列資料時考慮到上下文資訊,這種結構使得RNN能夠處理不定長的序列資料,如自然語言文字、時間序列資料等。RNN結構如下圖所示:

cke_140.png

左邊是RNN網路,右邊是RNN網路按時序展開的形式,為什麼要按照時序展開?主要是RNN中 隱狀態更新需要依賴上一次的隱狀態資訊,就是我們理解的記憶資訊。RNN 的基本結構包括一個隱藏層,其中的神經元透過時間步驟連線,允許資訊從一個時間步驟傳遞到下一個時間步驟。RNN 在每個時間步驟上接收一個輸入並輸出一個隱藏狀態。這個隱藏狀態包含了網路在當前時間步驟所看到的序列的資訊。這個隱藏狀態可以被用作下一個時間步驟的輸入。對於一個時間步驟 t,RNN 的隱藏狀態的計算如下:

cke_141.png

儘管 RNN 具有處理序列資料的能力,但它們在處理長序列時會面臨梯度消失或梯度爆炸的問題。這是因為透過時間反向傳播時,梯度可能會迅速縮小或增大,導致模型難以學習長期依賴關係。為了解決梯度消失問題,出現了一些改進的 RNN 變體,如長短時記憶網路(LSTM)和門控迴圈單元(Gated Recurrent Unit,GRU)。這些模型透過引入門控機制,允許網路選擇性地記住和遺忘資訊,從而更有效地處理長序列。

2.LSTM結構

LSTM全稱Long Short Term Memory networks,是普通RNN的變體,可以有效解決長期依賴的問題。LSTM的核心是元胞(cell)狀態,輸入的資訊從上方的水平線經過元胞,期間只與其他實線進行少量互動,表示一些線性變換,這使得輸入的資訊能夠較完整的儲存下來,也就是說可以保留長期記憶。而LSTM對資訊進行選擇性的保留,是透過門控機制進行實現的。門結構可以控制透過元胞資訊的多少。它實際上是對輸入資訊進行線性變換後,再透過一個sigmoid層來實現的,最終將輸入最終轉為一個係數向量,值的範圍在0~1,可以理解為保留的資訊的佔比。如果值為0,則表示將對應的資訊全部丟棄,如果為1則表示將對應的資訊全部保留。LSTM共有三種門結構,分別是遺忘門、輸入門、輸出門,LSTM結構如下圖所示:

cke_142.png

遺忘門用來控制在元胞(cell) 狀態裡哪些資訊需要進行遺忘,以使cke_143.png在流動的過程中進行適當的更新。它接收cke_144.pngcke_145.png作為輸入引數,透過sigmoid層得到對應的遺忘門的引數cke_146.png。具體公式如下:

cke_147.png

接下來就需要更新細胞狀態cke_148.png了。首先LSTM需要生成一個用來更新的候選值,記為cke_149.png,透過tanh層來實現。然後還需要一個輸入門引數cke_150.png來決定更新的資訊,同樣透過sigmoid層實現。最後將cke_151.pngcke_152.png相乘得到更新的資訊,同時將上面得到的遺忘門cke_153.png和舊元胞狀態cke_154.png相乘,以忘掉其中的一些資訊,二者相結合,便得到更新後的狀態cke_155.png,具體公式如下:cke_156.png

LSTM需要計算最後的輸出資訊,該輸出資訊主要由元胞狀態cke_157.png決定,但是需要經過輸出門進行過濾處理。首先要將元胞狀態cke_158.png的值規範化到[-1,1],這透過tanh層來實現。然後依然由一個sigmoid層得到輸出門引數cke_159.png,最後將cke_160.png和規範化後的元胞狀態進行點乘,得到最終過濾後的結果,具體公式如下:

cke_161.png

3.GRU結構

GRU (Gated Recurrent Unit)是一種用於迴圈神經網路(RNN)的門控機制,旨在解決長期依賴問題並緩解梯度消失或爆炸現象。GRU的結構比LSTM (Long Short-Term Memory)更簡單,它包含兩個門:更新門(update gate)和重置門(reset gate)。更新門負責控制前一時刻的狀態資訊對當前時刻狀態的影響,其值越大,表明引入的前一時刻狀態資訊越多。重置門則控制忽略前一時刻狀態資訊的程度,其值越小,表明忽略得越多。GRU結構如下圖所示:

cke_162.png

它只有兩個門,對應輸出更新門(update gate)向量:cke_163.png和重置門(reset gate)向量:cke_164.png,更新門負責控制上一時刻狀態資訊cke_165.png對當前時刻狀態的影響,更新門的值越大說明上一時刻的狀態資訊cke_166.png帶入越多。而重置門負責控制忽略前一時刻的狀態資訊的程度,重置門的值越小說明忽略的越多。接下來,“重置”之後的重置門向量 cke_167.png與前一時刻狀態cke_168.png卷積cke_169.png,再將cke_170.pngcke_171.png輸入進行拼接,再透過啟用函式tanh來將資料放縮到-1~1的範圍內。這裡包含了輸入資料cke_172.png,並且將上一時刻狀態的卷積結果新增到當前的隱藏狀態,透過此方法來記憶當前時刻的狀態。

最後一個步驟是更新記憶階段,此階段同時進遺忘和記憶兩個步驟,使用同一個門控同時進行遺忘和選擇記憶(LSTM是多個門控制) 。

4.GNMT結構

NMT是神經網路翻譯系統,通常會含用兩個RNN,一個用來接受輸入文字,另一個用來產生目標語句,但是這樣的神經網路系統有三個弱點:1.訓練速度很慢並且需要巨大的計算資源,由於數量眾多的引數,其翻譯速度也遠低於傳統的基於短語的翻譯系統(PBMT);2.對罕見詞的處理很無力,而直接複製原詞在很多情況下肯定不是一個好的解決方法;3.在處理長句子的時候會有漏翻的現象。而GNMT中,RNN使用的是8層(實際上Encoder是9層,輸入層是雙向LSTM。)含有殘差連線的神經網路,殘差連線可以幫助某些資訊,比如梯度、位置資訊等的傳遞。同時,attention層與decoder的底層以及encoder的頂層相連線,如下圖所示:

cke_173.png

GNMT encoder將輸入語句變成一系列的向量,每個向量代表原語句的一個詞,decoder會使用這些向量以及其自身已經生成的詞,生成下一個詞。encoder和decoder透過attention network連線,這使得decoder可以在產生目標詞時關注原語句的不同部分。上面提到,多層堆疊的LSTM網路通常會比層數少的網路有更好的效能,然而,簡單的錯層堆疊會造成訓練的緩慢以及容易受到剃度爆炸或梯度消失的影響,在實驗中,簡單堆疊在4層工作良好,6層簡單堆疊效能還好的網路很少見,8層的就更罕見了,為了解決這個問題,在模型中引入了殘差連線,如圖:

cke_174.png

一句話的譯文所需要的關鍵詞可能在出現在原文的任何位置,而且原文中的資訊可能是從右往左的,也可能分散並且分離在原文的不同位置,因為為了獲得原文更多更全面的資訊,雙向RNN可能是個很好的選擇,在本文的模型結構中,只在Encoder的第一層使用了雙向RNN,其餘的層仍然是單向RNN,粉色的LSTM從左往右的處理句子,綠色的LSTM從右往左,二者的輸出先是連線,然後再傳給下一層的LSTM,如下圖Bi-directions RNN示意圖:

cke_175.png

5.Transformers結構

Transformer模型是一種基於自注意力機制的神經網路模型,旨在處理序列資料,特別是在自然語言處理領域得到了廣泛應用。Transformer模型的核心是自注意力機制(Self-Attention Mechanism),它允許模型關注序列中每個元素之間的關係。這種機制透過計算注意力權重來為序列中的每個位置分配權重,然後將加權的位置向量作為輸出。模型結構上,Transformer由一個編碼器堆疊和一個解碼器堆疊組成,它們都由多個編碼器和解碼器組成。編碼器主要由多頭自注意力(Multi-Head Self-Attention)和前饋神經網路組成,而解碼器在此基礎上加入了編碼器-解碼器注意力模組。Transformer結構如下所示:

cke_176.png

基於Transformer 結構的編碼器和解碼器結構上圖所示,左側和右側分別對應著編碼器(Encoder)和解碼器(Decoder)結構。它們均由若干個基本的Transformer 塊(Block)組成(對應著圖中的灰色框)。這裡N× 表示進行了N 次堆疊。每個Transformer 塊都接收一個向量序列。主要涉及到如下幾個模組:

1)嵌入表示層:對於輸入文字序列,首先透過輸入嵌入層(Input Embedding)將每個單詞轉換為其相對應的向量表示。通常直接對每個單詞建立一個向量表示。由於Transfomer 模型不再使用基於迴圈的方式建模文字輸入,序列中不再有任何資訊能夠提示模型單詞之間的相對位置關係。在送入編碼器端建模其上下文語義之前,一個非常重要的操作是在詞嵌入中加入位置編碼(Positional Encoding)這一特徵。具體來說,序列中每一個單詞所在的位置都對應一個向量。這一向量會與單詞表示對應相加並送入到後續模組中做進一步處理。在訓練的過程當中,模型會自動地學習到如何利用這部分位置資訊。

2)注意力層:自注意力(Self-Attention)操作是基於Transformer 的機器翻譯模型的基本操作,在源語言的編碼和目標語言的生成中頻繁地被使用以建模源語言、目標語言任意兩個單詞之間的依賴關係。給定由單詞語義嵌入及其位置編碼疊加得到的輸入表示{xi ∈ Rd}ti=1,為了實現對上下文語義依賴的建模,進一步引入在自注意力機制中涉及到的三個元素:查詢qi(Query),鍵ki(Key),值vi(Value)。在編碼輸入序列中每一個單詞的表示的過程中,這三個元素用於計算上下文單詞所對應的權重得分。直觀地說,這些權重反映了在編碼當前單詞的表示時,對於上下文不同部分所需要的關注程度。

3)前饋層:前饋層接受自注意力子層的輸出作為輸入,並透過一個帶有Relu 啟用函式的兩層全連線網路對輸入進行更加複雜的非線性變換。實驗證明,這一非線性變換會對模型最終的效能產生十分重要的影響。

cke_177.png

其中W1, b1,W2, b2 表示前饋子層的引數。實驗結果表明,增大前饋子層隱狀態的維度有利於提升最終翻譯結果的質量,因此,前饋子層隱狀態的維度一般比自注意力子層要大。

4) 殘差連線與層歸一化:由Transformer 結構組成的網路結構通常都是非常龐大。編碼器和解碼器均由很多層基本的Transformer 塊組成,每一層當中都包含複雜的非線性對映,這就導致模型的訓練比較困難。因此,研究者們在Transformer 塊中進一步引入了殘差連線與層歸一化技術以進一步提升訓練的穩定性。具體來說,殘差連線主要是指使用一條直連通道直接將對應子層的輸入連線到輸出上去,從而避免由於網路過深在最佳化過程中潛在的梯度消失問題。

Transformer 模型由於其處理區域性和長程依賴關係的能力以及可並行化訓練的特點而成為一個強大的替代方案,如 GPT-3、ChatGPT、GPT-4、LLaMA 和 Chinchilla 等都展示了這種架構的能力,推動了自然語言處理領域的前沿。儘管取得了這些重大進展,Transformer 中固有的自注意力機制帶來了獨特的挑戰,主要是由於其二次複雜度造成的。這種複雜性使得該架構在涉及長輸入序列或資源受限情況下計算成本高昂且佔用記憶體。這也促使了大量研究的釋出,旨在改善 Transformer 的擴充套件性,但往往以犧牲一些特性為代價。正是在此背景之下,一個由 27 所大學、研究機構組成的開源研究團隊,聯合發表論文《 RWKV: Reinventing RNNs for the Transformer Era 》,文中介紹了一種新型模型:RWKV(Receptance Weighted Key Value),這是一種新穎的架構,有效地結合了 RNN 和 Transformer 的優點,同時規避了兩者的缺點。RWKV能夠緩解 Transformer 所帶來的記憶體瓶頸和二次方擴充套件問題,實現更有效的線性擴充套件,同時保留了使 Transformer 在這個領域佔主導的一些性質。

四、 RWKV模型

RWKV是一個結合了RNN與Transformer雙重優點的模型架構,是一個RNN架構的模型,但是可以像transformer一樣高效訓練。RWKV 模型透過 Time-mix 和 Channel-mix 層的組合,以及 distance encoding 的使用,實現了更高效的 Transformer 結構,並且增強了模型的表達能力和泛化能力。Time-mix 層與 AFT(Attention Free Transformer)層相似,採用了一種注意力歸一化的方法,以消除傳統 Transformer 模型中存在的計算浪費問題。Channel-mix 層則與 GeLU(Gated Linear Unit)層相似,使用了一個 gating mechanism 來控制每條通道的輸入和輸出。另外,RWKV 模型採用了類似於 AliBi 編碼的位置編碼方式,將每個位置的資訊新增到模型的輸入中,以增強模型的時序資訊處理能力。這種位置編碼方式稱為 distance encoding,它考慮了不同位置之間的距離衰減特性,RWKV結構如下圖所示:

cke_178.png

這裡我們以下圖的自迴歸例子學習RWKV的推理過程,用 (x,y ) 表示樣本資料和樣本標籤,圖中有3對資料: (my,name) , (my name ,is ) 和 (my name is , Bob) . 另外,在語言模型中,標記偏移(token shift)是一種常見的技術,用於訓練模型以預測給定上下文中下一個標記(單詞、字元或子詞單元)的任務。下圖中的標記偏移技術是向右移動一個位置,生成的三個token-shift 為:"0 my","my name","name is"。為什麼要進行標記偏移呢? 這是因為這樣做具有遞迴巢狀的思想,比如:"name"向量與"my"向量有關,而"is"向量與"name"向量有關,所以"is"向量自然與"name"向量有關。好處是:給融入迴圈神經網路思想帶來了便利的同時還保持了並行性。具體流程下面的Time-Mix模組和Channel-Mix模組會詳細介紹。如下圖所示,這兩個模組是RWKV架構的主要模組。Time-Mix模組可以看成根據隱狀態(State)生成候選預測向量,Channel-Mix模組則可以看成生成最終的預測向量。

cke_179.png

1.Time Mixing模組

對於t時刻,給定單詞cke_180.png和前一個單詞cke_181.png,Time-Mix模組公式如下:

cke_182.png

其中,cke_183.png即AFT注意力機制中的cke_184.png,且cke_185.pngcke_186.png,不同的是,偏置值cke_187.png以控制cke_188.png以及距離越遠權重越低,這相當於可學習的位置編碼;u也是位置編碼,表示t時刻認為過去哪些時刻比較重要;cke_189.png表示token-shift,由cke_190.pngcke_191.png得出。因此,對於上述公式,給定的t時刻,唯一不能確定的就是cke_192.png, 因為需要得到前t-1個Key向量和Value向量。此時的Time-Mix模組看起來還是規範的注意力機制,但可以寫成遞迴迴圈的形式。Time-Mix模組之間存在著從左往右的Statas傳遞,即引入了隱狀態。其融合迴圈神經網路思想的Time-Mix模組的資料流動如下圖所示,

cke_193.png

cke_194.png的計算可以寫成遞迴迴圈的形式:

cke_195.png

其中,隱狀態cke_196.png雖然依然需要t時刻之前的計算,但過程已經簡化,可以將隱狀態的計算獨立出來,也降低了記憶體要求,並且轉化為張量積計算也大大提升了並行性。論文作者也為WKV設計了特定的CUDA核心,還提出了實踐技巧,下次討論。

2.Channel Mixing模組

Channel-Mix模組可以看成是為了生成最終的預測向量。輸入Time-Mix模組的cke_197.png向量和cke_198.png向量還沒有上文資訊,經由Time-Mix模組得出的cke_199.png向量和cke_200.png 向量融合了他們各自時刻之前的上文資訊,我猜測Channel-Mix模組如它命名所示,將不同時刻的資訊進一步融合得到最終的預測變數,其公式為:

cke_201.png

因為cke_202.png向量和cke_203.png向量具有上文的資訊,所以不用值Value向量,最後應用一個類似遺忘門的操作,丟棄不必要的歷史資訊。

3.RWKV的優勢

1)高效訓練和推理:RWKV 模型既可以像傳統 Transformer 模型一樣高效訓練,也具有類似於 RNN 的推理能力。這使得 RWKV 模型可以支援序列模式和高效推理,也可以支援並行模式(並行推理訓練)和長程記憶。

2)支援高效訓練:RWKV 模型使用了 Time-mix 和 Channel-mix 層,以消除傳統 Transformer 模型中存在的計算浪費問題。這使得 RWKV 模型在訓練過程中具有更高的效率和更快的速度。

3)支援大規模自然語言處理任務:RWKV 模型可以處理大規模的自然語言處理任務,如文字分類、命名實體識別、情感分析等。

4)可擴充套件性強:RWKV 模型具有良好的可擴充套件性,可以方便地進行模型擴充套件和改進,以適應不同任務的需求。

4.RWKV模型引數

目前官方已經就RWKV開源了多個模型。主要是Raven系列模型,Raven是基於RWKV-4架構在Pile資料集上訓練和微調的大模型,做過指令微調或者chat微調版本。此外,也包括了非Raven版本的RWKV-4的模型。

cke_204.png

五、 RWKV模型程式碼閱讀

1.RWKV模型推理程式碼

cke_205.png

程式碼解釋:

1~2行:引入程式碼需要的庫

4行:對輸出進行校驗

6~9行:載入RWKV/rwkv-4-169m-pile模型,並且輸入提示詞

11~12行:執行模型,解碼生成內容

13行:期望輸出與真實輸出內容進行校驗

2.Channel Mixing模組程式碼:

cke_206.png

x通道混合層接受與此標記對應的輸入,以及x與前一個標記對應的輸入,我們稱之為last_x。last_x儲存在這個 RWKV 層的state. 其餘輸入是學習RWKV 的 parameters。首先,我們使用學習的權重對x和進行線性插值last_x。我們將此插值x作為輸入執行到具有平方 relu 啟用的 2 層前饋網路,最後與另一個前饋網路的 sigmoid 啟用相乘(在經典 RNN 術語中,這稱為門控)。請注意,就記憶體使用而言,矩陣Wk,Wr,Wv包含幾乎所有引數(1024×1024 matrices它們是矩陣,而其他變數只是 1024 維向量)。矩陣乘法(@在 python 中)貢獻了絕大多數所需的計算。

3.Time mixing模組程式碼:

時間混合的開始類似於通道混合,透過將此標記的插入x到最後一個標記的x。然後我們應用學到的矩陣以獲得“key”, “value” and “receptance”向量。

cke_207.png

六、與其他模型的比較

1.複雜度對比

從和Transformer,Reformer,Performer,Linear Transformers,AFT-full,AFT-local,MEGA等模型的複雜度比較中可以看的出來,RWKV模型的時間複雜度和空間負責度都是最低的,費別為O(Td)和O(d),其中T 表示序列長度,d 表示特徵維度,c 表示 MEGA 的二次注意力塊大小。

cke_208.png

2.精度對比

RWKV 似乎可以像 SOTA transformer一樣縮放。至少多達140億個引數。在同等規模引數中,RWKV-4系列與Pythia和GPT-J比都是很有優勢的,對比如下圖所示:

cke_209.png

3.推理速度和記憶體佔用

RWKV網路與不同型別的Transformer效能的實驗結果對比如下圖所示。RWKV時間消耗隨序列長度是線性增加,且時間消耗遠小於各種型別的Transformer。

cke_210.png

RWKV與Transformer預訓練模型(BLOOM、OPT、Pythia)效果對比測試如下圖所示。在六個基準測試中(Winogrande、PIQA、ARC-C、ARC-E、LAMBADA 和 SciQ),RWKV 與開源二次複雜度 transformer 模型 Pythia、OPT 和 BLOOM 具有相當的競爭力。RWKV 甚至在四個任務(PIQA、OBQA、ARC-E 和 COPA)中勝過了 Pythia 和 GPT-Neo。

cke_211.png

下圖顯示,增加上下文長度會導致 Pile 上的測試損失降低,這表明 RWKV 能夠有效利用較長的上下文資訊。

cke_212.png

七、小結

本節我們學習了RWKV模型,我們掌握了RWKV模型結構的整個演進過程,從最初的RNN結構,到LSTM結構,到GRU結構,到GNTM模型,到Transformers模型,最後到RWKV模型,我們學習了每種模型結構出現的原因,以及其對應的優勢和不足。接下來,我們學習了RWKV模型,Time Mixing模組和Channel Mixing模組。我們透過學習RWKV模型的python程式碼,對RWKV模型從複雜度,精度,推理速度,記憶體佔用等四個維度和其他模型進行了對比。

透過本節學習,我們對RWKV模型有了一個全面的認識,RWKV模型正在作為一顆在大模型領域的新星正在受到越來越多社群開發者的關注,希望RWKV模型在接下來的版本迭代過程中能給大家帶來更多的驚喜。

點選關注,第一時間瞭解華為雲新鮮技術

相關文章