Transformer 的強大實力已經在諸多大型語言模型(LLM)上得到了證明,但該架構遠非完美,也有很多研究者致力於改進這一架構,比如機器之心曾報導過的 Reformer 和 Infini-Transformer。
今天我們又將介紹另一種新型 Transformer 架構:Differential Transformer(差分 Transformer,簡稱 Diff Transformer)。該架構來自微軟研究院和清華大學,有四位共一作者:Tianzhu Ye、Li Dong、Yuqing Xia、Yutao Sun。
論文標題:Differential Transformer
論文地址:https://arxiv.org/pdf/2410.05258
在 Hacker News 及 Twitter 等社交網路上,該論文都反響熱烈,有網友表示差分 Transformer 提出的改進簡單又美麗,而帶來的提升又非常顯著。
甚至已有開發者做出了差分 Transformer 的輕量實現!
差分 Transformer 的輕量實現,https://github.com/Jaykef/ai-algorithms/blob/main/DIFF_Transformer.ipynb
那麼差分 Transformer 彌補了原生 Transformer 的哪些問題呢?如下圖所示,Transformer 往往會過度關注不相關的上下文,該團隊將此稱為注意力噪聲(attention noise)。而差分 Transformer 則能放大對答案範圍的注意力並消除噪音,從而增強上下文建模的能力。這就要用到該團隊新提出的差分注意力機制(differential attention mechanism)了。
差分注意力機制可以消除注意力噪聲,鼓勵模型重點關注關鍵資訊。該方法有些類似於電氣工程中的降噪耳機和差分放大器。
下面我們就來詳細瞭解一下差分 Transformer 的設計思路。
差分 Transformer
差分 Transformer 是一種用於序列建模的基礎模型架構。為了方便說明,他們使用了僅解碼器(decoder-only)模型作為示例來描述該架構。
該模型堆疊了 L 個 Diff Transformer 層。給定一個輸入序列 x,將輸入嵌入打包成 X^0。輸入會被進一步上下文化來獲得輸出 X^L。每一層都由兩個模組組成:一個差分注意力模組和之後的前向網路模組。
相比於 Transformer,差分 Transformer 的主要差別在於使用差分注意力替換了傳統的 softmax 注意力,同時保持整體宏觀佈局不變。此外,他們也參考 LLaMA 採用了 pre-RMSNorm 和 SwiGLU 這兩項改進措施。
差分注意力
差分注意力機制的作用是將查詢、鍵和值向量對映成輸出。這裡使用查詢和鍵向量來計算注意力分數,然後計算值向量的加權和。
此處的關鍵設計是使用一對 softmax 函式來消除注意力分數的噪聲。具體來說,給定輸入 X,首先將它們投射成查詢、鍵和值 Q_1、Q_2、K_1、K_2、V。然後差分注意力運算元 DiffAttn (・) 透過以下方式計算輸出:
其中 W^Q、W^K 、W^V 是引數,λ 是可學習的標量。為了同步學習動態,將標量 λ 重新引數化為:
其中 λ_q1、λ_k1、λ_q2、λ_k2 是可學習的向量,λ_init ∈ (0, 1) 是用於初始化 λ 的常數。該團隊透過經驗發現,設定 λ_init = 0.8 − 0.6 × exp (−0.3・(l − 1)) 在實踐中效果很好,其中 l ∈ [1, L] 表示層索引。它在實驗中被用作預設策略。
他們也探索了另一種初始化策略:對所有層使用相同的 λ_init(例如 0.8)。如後面消融研究所示,使用不同的初始化策略時,效能相對穩健。
差分注意力利用兩個 softmax 注意力函式之間的差來消除注意力噪聲。這個想法類似於電氣工程中提出的差分放大器,其中兩個訊號之間的差用作輸出,這樣就可以消除輸入的共模噪聲。此外,降噪耳機的設計也基於類似的想法。
多頭差分注意力機制
該團隊也為差分注意力使用了多頭機制。令 h 表示注意力頭的數量。他們對各個頭使用不同的投影矩陣 W^Q_i 、W^K_i 、W^V_i ,i ∈ [1, h]。標量 λ 在同一層內的頭之間共享。然後對頭輸出執行歸一化,並投射成最終結果,如下所示:
其中 λ_init 是 (2) 式中的常數標量,W^O 是可學習的投影矩陣,LN (・) 是對每個頭使用 RMSNorm,Concat (・) 的作用是沿通道維度將頭連線在一起。這裡使用一個固定乘數(1 − λ_init)作為 LN (・) 的縮放尺度,以使梯度與 Transformer 對齊。
逐頭歸一化
圖 2 使用了 GroupNorm (・) 來強調 LN (・) 獨立應用於每個 head。由於差分注意力往往具有更稀疏的模式,因此頭之間的統計資訊更加多樣化。為了改進梯度的統計情況,LN (・) 運算元會在連線操作之前對每個頭進行歸一化。
整體架構
其整體架構會堆疊 L 層,其中每層包含一個多頭差分注意力模組和一個前向網路模組。如此,便可將差分 Transformer 層描述為:
其中 LN (・) 是 RMSNorm,SwiGLU (X) = (swish (XW^G) ⊙ XW_1) W_2,且 W^G、W_1、W_2 是可學習的矩陣。
實驗
該團隊從以下角度評估了差分 Transformer 在 LLM 中的應用,包括對比評估、應用評估和消融研究。這裡我們僅關注實驗結果,更多實驗過程請訪問原論文。
語言建模評估
該團隊評估了差分 Transformer 的語言建模能力。為此,他們使用 1T token 訓練了一個 3B 大小的差分 Transformer 語言模型,並與之前的 Transformer 語言模型做了比較。
結果見表 1,其中報告的是在 LM Eval Harness 基準上的零樣本結果。
可以看到,3B 規模下,差分 Transformer 語言模型的表現優於之前的 Transformer 語言模型。此外,實驗也表明差分 Transformer 在多種任務上都勝過 Transformer,詳見原論文附錄。
與 Transformer 的可擴充套件性比較
該團隊也比較了新舊 Transformer 的可擴充套件性。結果見圖 3,其中 a 比較了模型規模方面的可擴充套件性,而 b 則是訓練 token 數量方面的可擴充套件性。
可以看到,在這兩個方面,差分 Transformer 的可擴充套件性均優於常規 Transformer:僅需後者 65% 左右的模型大小或訓練 token 數量就能達到相媲美的效能。
長上下文評估
當 3B 模型上下文長度增長至 64K,模型的表現又如何呢?又使用另外 1.5B token 訓練了 3B 版本的檢查點模型之後,該團隊發現隨著上下文長度的增加,累積平均負對數似然(NLL)持續下降。差分 Transformer 得到的 NLL 值低於常規 Transformer。見圖 4,這樣的結果表明,差分 Transformer 可以有效地利用不斷增加的上下文。
關鍵資訊檢索
為了檢驗差分 Transformer 檢索關鍵資訊的能力,該團隊執行了 Needle-In-A-Haystack(草堆找針)測試。
表 2 給出了 4K 上下文長度的情況,其中 N 是針的數量,R 是查詢引用的數量。可以看到,差分 Transformer 的多針檢索準確度高於常規 Transformer,尤其是當針數量較多時,差分 Transformer 的優勢會更加明顯。
那麼當上下文長度提升至 64K 時,又會如何呢?結果見圖 5,這裡使用的上下文長度在 8K 到 64K 之間,使用了 N = 8 和 R = 1 的設定。
可以看到,在不同的上下文長度下,差分 Transformer 能夠保持相對穩定的效能。而當上下文長度越來越大時,常規 Transformer 的效能會逐漸下降。
另外,表 3 展示了分配給關鍵資訊檢索任務的答案範圍和噪聲上下文的注意力分數。該分數可代表模型保留有用資訊、抵抗注意力噪聲的能力。
可以看到,相比於常規 Transformer,差分 Transformer 能為答案範圍分配更高的注意力分數,同時為注意力噪聲分配更低的注意力分數。
上下文學習能力評估
該團隊從兩個角度評估模型的上下文學習能力,包括多樣本分類和上下文學習的穩健性。
圖 6 展示了新舊 Transformer 模型的多樣本分類結果。結果表明,在不同的資料集和不同的演示樣本數量上,差分 Transformer 均穩定地優於 Transformer。此外,差分 Transformer 的平均準確度優勢也很明顯,從 5.2% 到 21.6% 不等。
圖 7 則展示了兩種模型的上下文學習穩健性結果。該分析基於 TREC 資料集,並且採用了兩種提示詞格式:示例隨機排列(圖 7a)和按類別交替排列(圖 7b)。
在這兩種設定下,差分 Transformer 的效能方差要小得多。結果表明,新方法在上下文學習任務中更為穩健。相比之下,Transformer 容易受到順序排列的影響,導致最佳結果與最差結果之間差距巨大。
上下文幻覺評估
該團隊基於文字摘要和問答任務評估了模型的上下文幻覺現象。結果見表 4。
可以看到,相比於常規 Transformer,差分 Transformer 在摘要和問答任務上的上下文幻覺更低。該團隊表示,原因可能是差分 Transformer 能更好地關注任務所需的基本資訊,而不是無關上下文。
啟用異常值分析
在 LLM 中,一部分啟用值明顯大於大多數啟用值的現象被稱為啟用異常值(activation outliers)。異常值導致訓練和推理過程中模型量化困難。實驗表明差分 Transformer 可以降低啟用異常值的幅度,從而可能實現更低的量化位寬。
表 5 展示了兩個訓練得到 Transformer 和差分 Transformer 模型的啟用值統計情況。這裡分析了兩種型別的啟用,包括注意力 logit(即 pre-softmax 啟用)和隱藏狀態(即層輸出)。可以看到,儘管中位數相似,但與 Transformer 相比,差分 Transformer 的較大啟用值要低得多。這表明新方法產生的啟用異常值較少。
圖 8 則展示了將注意力 logit 量化到更低位的情況。這裡使用的方案是:使用 absmax 量化的動態後訓練量化。其中,16 位配置表示未經量化的原始結果。模型逐步量化為 8 位、6 位和 4 位。這裡報告的是在 HellaSwag 上的零樣本準確度,但該團隊也指出在其它資料集上也有類似表現。
從圖中可知,即使降低位寬,差分 Transformer 也能保持較高效能。相較之下,常規 Transformer 的準確度在 6 位和 4 位量化時會顯著下降。這一結果表明,差分 Transformer 本身就能緩解注意力分數中的啟用異常值問題,從而可為低位 FlashAttention 的實現提供新機會。
最後,該團隊也進行了消融實驗,證明了各個新設計的有效性。