DenseMamba:大模型的DenseNet時刻,Mamba和RetNet精度顯著提升

机器之心發表於2024-03-11
近期,來自華為諾亞方舟實驗室的研究者提出了 DenseSSM,用於增強 SSM 中各層間隱藏資訊的流動。透過將淺層隱藏狀態有選擇地整合到深層中,DenseSSM 保留了對最終輸出至關重要的精細資訊。DenseSSM 在保持訓練並行性和推理效率的同時,透過密集連線實現了效能提升。該方法可廣泛應用於各種 SSM 型別,如 Mamba 和 RetNet。

隨著 ChatGPT 的突破性進展,大型語言模型(LLMs)迎來了一個嶄新的里程碑。這些模型在語言理解、對話互動和邏輯推理方面展現了卓越的效能。過去一年,人們目睹了 LLaMA、ChatGLM 等模型的誕生,它們基於 Transformer 架構,採用多頭自注意力(MHSA)機制來捕捉詞彙間的複雜關係,儘管 MHSA 模組在模型中扮演著核心角色,但其在推理過程中對計算和記憶體資源的需求卻極為龐大。具體來說,對於長度為 N 的輸入句子,自注意力的計算複雜度高達 O (N^2),而記憶體佔用則達到了 O (N^2D),其中 D 是模型的維度。

為了應對這一挑戰,最新的研究致力於簡化 Transformer 架構,以降低其在計算和空間上的複雜度。研究者們探索了多種創新方法,包括卷積語言模型、迴圈單元、長上下文模型,以及狀態空間模型(SSMs)。這些新興技術為構建高效能的 LLMs 提供了強有力的替代方案。SSMs 透過引入高效的隱藏狀態機制,有效處理長距離依賴問題,同時保持了訓練的並行性和推理的高效率。隱藏狀態能夠在時間維度上傳遞資訊,減少了在每一步中訪問歷史詞彙的計算負擔。透過狀態轉移引數 A,隱藏狀態能夠將前一時間步的資訊傳遞至當前時間步,實現對下一個詞彙的自迴歸預測。

儘管隱藏狀態在 SSMs 中起著至關重要的作用,但其在以往的研究中並未得到充分研究。不同層的權重和隱藏特徵包含了從細粒度到粗粒度的多層次資訊。然而,在早期的 SSMs 版本中,隱藏狀態僅在當前層內流動,限制了其傳遞更深層資訊的能力,從而影響了模型捕獲豐富層次資訊的能力。

為了解決這個挑戰,華為諾亞方舟實驗室的科研團隊發表了新工作《DenseMamba: State Space Models with Dense Hidden Connection for Efficient Large Language Models》, 提出一個適用於各類 SSM 模型例如 Mamba 和 RetNet 的 DenseSSM 方法,該方法有選擇地將淺層隱藏狀態整合到深層,保留了對最終輸出至關重要的淺層細粒度資訊,以增強深層感知原始文字資訊的能力。

圖片

  • 論文連結:https://arxiv.org/abs/2403.00818

  • 專案主頁:https://github.com/WailordHe/DenseSSM

文章首先分析了狀態空間模型(SSMs)中的隱藏狀態退化問題,

圖片

上標 “l” 表示第 l 個塊。其中,Θ(·) 是從 SSM 模組的最後一個輸出到輸入的轉換,例如卷積和前饋網路(FFN)。從公式 (7) 可以看出,從第 (l-m) 層到第 l 層的隱藏資訊傳遞需要經過 m 個變換塊和 m 次 BC 矩陣乘法。這樣複雜的計算過程可能導致顯著的資訊丟失,這意味著在第 l 層嘗試檢索淺層的某些資訊變得非常困難和不清晰。

方法

密集(Dense)隱藏層連線

在上述分析中發現隨著層深度的增加,SSM 中重要隱藏狀態的衰減。因此,DenseSSM 提出了一種密集連線的隱藏狀態方法,以更好地保留來自淺層的細粒度資訊,增強深層感知原始文字資訊的能力。對於第 l 個塊,DenseSSM 在其前 m 個塊中密集連線隱藏狀態。

圖片

首先,收集淺層隱藏狀態,並引入一個選擇性轉換模組 φ,同時將它們投影到目標層的子空間並選擇有用的部分:

圖片

操作圖片是融合中間隱藏向量和當前隱藏狀態的函式。具有所提出的密集隱藏層連線的 SSM 被稱為 DenseSSM, 下圖為遞迴模式的 DenseSSM 示例。

圖片

DenseSSM 也可以基於卷積模式以實現高效訓練。根據狀態空間模型(SSM)的公式圖片可以得到:

圖片

這個過程可以透過對輸入序列圖片進行卷積來實現:

圖片

在文章所提出的 DenseSSM 中,可以獲得隱藏狀態加強的 SSM 的輸出:

圖片

DenseSSM 方法的並行實現示例圖:

圖片

Selective Transition Module (選擇性轉換模組)

圖片

選擇性轉換模組 φ(·) 的目的是將輸入投影到目標子空間,並同時選擇隱藏資訊的有用部分。透過投影層和門控選擇機制實現了選擇性轉換模組,如上圖所示。首先,前 m 個 SSM 塊中的隱藏狀態會被投影到相同的空間:

圖片

然後,根據輸入圖片生成門控權重,並使用它們來選擇有用的隱藏狀態:

圖片

在實踐中作者保持了簡單且高效的實現。投影層使用線性變換實現,而門控模組則使用引數高效的帶有啟用函式的兩層 MLP。

Hidden Fusion Module (隱藏層融合模組)

選擇性轉換模組後從淺層獲得了選擇的隱藏狀態,即圖片後,DenseSSM 方法利用一個隱藏融合模組將這些精選的淺層隱藏狀態與當前層的隱藏狀態結合起來。由於這些精選狀態已經被投影到相同的空間,因此可以簡單地將它們累加到當前層的隱藏狀態上:

圖片

為了保持模型的高效性,其他可能的實現方式,例如拼接和交叉注意力機制沒有被使用。

擴充套件到 RetNet

RetNet 可以被視為一種狀態空間模型,它利用線性注意力來簡化自注意力的計算複雜度。與標準 Transformer 相比具有快速推理和並行化訓練兼得的優勢。

圖片

其中,圖片是迴圈狀態, RetNet 的密集 KV 連線執行方式如下。首先,淺層的 K 和 V 被連線起來:

圖片

然後,這些 K 和 V 被注入到當前層的原始鍵(或值)中:

圖片

配備了使用所提出 DenseSSM 方法的密集鍵值(KV)連線的 RetNet 被稱為 DenseRetNet,如下圖所示。

圖片

此外,DenseRetNet 也可以在並行模式下實現,也就是說,可以在 GPU 或 NPU 上並行訓練。DenseRetNet 的並行模式公式如下:

圖片

實驗

文章進行了全面的實驗,以驗證所提出的 DenseSSM 的有效性。這些實驗在不同的架構上進行,包括 RetNet 和 Mamba。

預訓練資料

在實驗中,選擇了 The Pile 資料集的一個子集,並從頭開始訓練所有模型。為了確保訓練集包含 150 億(15B)個 tokens,對資料集進行了隨機抽樣。在所有實驗中,統一使用了 LLaMA 分詞器來處理這些資料。

評估資料集

在評估模型效能時,特別關注了模型在多種下游任務上的零樣本和少樣本學習能力。這些任務包括了一系列測試常識推理和問答的資料集,例如 HellaSwag、BoolQ、COPA、PIQA、Winograd、Winogrande、StoryCloze、OpenBookQA、SciQ、ARC-easy 和 ARC-challenge。此外,文章還報告了 WikiText 和 LAMBADA 的詞困惑度指標。所有評估都透過使用 LM evaluation harness 標準化的評估工具進行,以確保評估模型能力的一致性。

實驗設定

為了驗證提出的 DenseSSM 機制的有效性,選擇了 350M 和 1.3B 兩種模型規格進行實驗。所有模型都是從頭開始訓練的,並進行了一個 Epoch 的訓練,共使用了 1.5B tokens。訓練時,設定訓練的 batch size 為 0.5M,序列長度為 2048 個 token。訓練過程中使用了 AdamW 最佳化器,並採用了多項式學習率衰減,warm-up 比例設定為總訓練步數的 1.5%。權重衰減設定為 0.01,梯度裁剪設定為 1。

DenseRetNet 的實驗

DenseRetNet 模型的大小和超引數設定詳細列出如下。此外,DenseRetNet 模型中還進一步整合了全域性注意力單元(GAU)。GAU 將注意力機制與前饋網路(FFN)塊結合為一個單元,這使得模型能夠同時進行通道混合和 token 混合。與原始的 GAU 不同,多頭機制仍然被採用以實現多尺度的指數衰減,這種設計旨在提高模型對不同尺度特徵的捕捉能力,從而提升效能。

圖片

在通用語料庫以及包括常識推理和問答在內的多種下游任務上,對 DenseRetNet 模型進行了評估。實驗結果的比較表格顯示,DenseRetNet 模型在 Wikitext 和 LAMBADA 語料庫上取得了更低的困惑度。此外,在零樣本和少樣本設定的下游任務中,DenseRetNet 表現出了顯著的優勢。與 RetNet 相比,DenseRetNet 顯著提升了效能,並且在與基於 Transformer 的語言模型的比較中,實現了更優越的效能表現。這些結果表明,DenseRetNet 在處理自然語言處理任務時,具有強大的能力和潛力。

圖片

DenseMamba 的實驗

下表詳細列出了 DenseMamba 模型的引數設定。由於 DenseMamba 使用的分詞器相比於 Mamba 模型中使用的 GPT-NeoX 分詞器規模較小,為了使引數數量相匹配,作者在模型中增加了兩層。除此之外,模型結構和其他訓練設定均遵循了 Mamba 論文中的描述。具體而言,對於 360M 引數的模型,學習率被設定為 3e-4;對於 1.3B 引數的模型,學習率被設定為 2e-4。在這兩種情況下,均沒有采用 dropout 技術。

圖片

下表比較了 DenseMamba 與相對應模型的效能。DenseMamba 在測試集上表現出卓越的困惑度和準確性,優於 Mamba 和其他基於 Transformer 的模型。

圖片

總結

文章提出了一個新的框架 ——DenseSSM(密集狀態空間模型),旨在透過增強隱藏資訊在不同層之間的流動來提升狀態空間模型(SSM)的效能。在 SSM 中,隱藏狀態是儲存關鍵資訊的核心單元,更有效地利用這些狀態對於模型的基本功能至關重要。為了實現這一目標,作者提出了一種方法,即從淺層收集隱藏狀態,並將它們有選擇性地融合到深層的隱藏狀態中,這樣可以增強 SSM 對文字低層資訊的感知能力。

DenseSSM 方法的設計考慮到了保持 SSM 原有的優點,如高效的自迴歸推理能力和高效的並行訓練特性。透過將 DenseSSM 方法應用於流行的架構,例如 RetNet 和 Mamba,作者成功地創造了具有更強大的基礎語言處理能力的新架構。這些新架構在公共基準測試中表現出了更高的準確性,證明了 DenseSSM 方法的有效性。

相關文章