位元組豆包大模型團隊突破殘差連線侷限!預訓練收斂最快加速80%

机器之心發表於2024-11-07
位元組跳動豆包大模型團隊於近日提出超連線(Hyper-Connections),一種簡單有效的殘差連線替代方案。面向殘差連線的主要變體的侷限問題,超連線可透過動態調整不同層之間的連線權重,解決梯度消失和表示崩潰(Representation Collapse)之間的權衡困境。在 Dense 模型和 MoE 模型預訓練中,超連線方案展示出顯著的效能提升效果,使收斂速度最高可加速 80%。

自從 ResNet 提出後,殘差連線已成為深度學習模型的基礎組成部分。其主要作用是 —— 緩解梯度消失問題,使得網路的訓練更加穩定。

但是,現有殘差連線變體在梯度消失和表示崩潰之間存在一種 “蹺蹺板式” 的權衡,無法同時解決。

為此,位元組豆包大模型 Foundation 團隊於近日提出超連線(Hyper-Connections),針對上述 “蹺蹺板式” 困境,實現了顯著提升。

該方法適用於大規模語言模型(LLMs)的預訓練,在面向 Dense 模型和 MoE 模型的實驗中,展示了顯著效能提升效果,使預訓練收斂速度最高可加速 80%。
圖片
研究團隊還發現,超連線在兩個小型的視覺任務中表現同樣優異,這表明,該方法在多個領域有廣泛的應用前景。
圖片
  • 論文標題:Hyper-Connections
  • 論文連結:https://arxiv.org/pdf/2409.19606

1. 超連線的核心思想

前文提及,殘差連線的兩種主要變體 Pre-NormPost-Norm 各自都有其侷限性,具體體現如下:

  • Pre-Norm:在每個殘差塊之前進行歸一化操作,可有效減少梯度消失問題。然而,Pre-Norm 在較深網路中容易導致表示崩潰,即深層隱藏表示過於相似,從而削弱了模型學習能力。
  • Post-Norm:在殘差塊之後進行歸一化操作,有助於減少表示崩潰問題,但也重新引入梯度消失問題。在 LLM 中,通常不會採用此方法。

超連線的核心思路在於 —— 引入可學習的深度連線(Depth-connections)和寬度連線(Width-connections)。

從理論上,這使得模型不僅能夠動態調整不同層之間的連線強度,甚至能重新排列網路層次結構,彌補了殘差連線在梯度消失和表示崩潰(Representation Collapse)之間的權衡困境。

深度連線與寬度連線

起初,該方法會將網路輸入擴充套件為 n 個隱向量(n 稱作 Expansion rate)。之後每一層的輸入都會是 n 個隱向量,超連線會對這些隱向量建立以下兩類連線:

  • 深度連線(Depth-Connections):這些連線類似於殘差連線,只為輸入與輸出之間的連線分配權重,允許網路學習不同層之間的連線強度。
  • 寬度連線(Width-Connections):這些連線使得每一層多個隱藏向量之間可進行資訊交換,從而提高模型表示能力。
圖片
靜態與動態超連線

超連線可以是靜態的,也可以是動態的。

其中,靜態超連線(Static Hyper-Connections, SHC)意味著連線權重在訓練結束後固定不變。而動態超連線(Dynamic Hyper-Connections, DHC)則對應連線權重可根據輸入動態調整。實驗表明,動態超連線效果更好。

2. 技術細節

超連線(Hyper-connections)

首先,考慮第 k 層的輸入隱藏向量圖片,網路的初始輸入為圖片,並將其複製 n 次,形成初始的超隱藏矩陣(Hyper Hidden Matrix):
圖片
這裡,n 稱為擴充套件率(Expansion Rate)。在第 k 層,輸入是上一層的超隱藏矩陣圖片,即:
圖片
對最後一層的超隱藏矩陣逐行求和,得到所需的隱藏向量,並透過一個投影層輸出網路最終的結果(在 Transformer 中即為歸一化層和解嵌入層)。

為了簡化後續分析的符號表示,作者省略層索引,直接將超隱藏矩陣表示為:
圖片
超連線可以用一個矩陣來表示,對於擴充套件率為 n 的情況,超連線矩陣 HC 如下:
圖片
考慮一層網路圖片,它可能是 Transformer 中的 attention 層或者是 FFN 層。超連線的輸出 圖片可以簡單地表示為:
圖片
也就是說,用 圖片作為權重對輸入 圖片進行加權求和,得到當前層的輸入圖片
圖片同時,圖片用於將 圖片對映到殘差超隱藏矩陣圖片,表示如下:
圖片
最終的輸出表示式為:
圖片
虛擬碼如下:
圖片
動態超連線的實現

超連線矩陣 圖片的元素可以動態依賴於輸入 圖片,動態超連線的矩陣表示為:
圖片
同樣,給定層 圖片和輸入圖片,可以得到動態超連線的輸出:
圖片
在實際操作中,團隊結合了靜態和動態矩陣來實現動態超連線,動態引數透過線性變換獲得。

為了穩定訓練過程,團隊線上性變換前引入歸一化,並在其後應用 tanh 啟用函式,透過一個可學習的小因子進行縮放。動態引數的計算公式如下:
圖片
實驗表明,動態超連線在語言建模任務中優於靜態超連線。

3. 為什麼使用超連線(Hyper-Connections)

研究團隊認為,殘差連線的兩種變體,即前歸一化(Pre-Norm)和後歸一化(Post-Norm),可以被視為不可訓練的超連線。

隨後,團隊引入了順序 - 並行二象性概念,展示了超連線如何動態最佳化層的排列以提升網路效能。

殘差連線是不可訓練的超連線

前歸一化和後歸一化的殘差連線可以表示為以下擴充套件率為 圖片的超連線矩陣:
圖片
其中,圖片圖片 分別表示神經網路層輸入和輸出的標準差,圖片表示它們之間的協方差。

對於 Pre-Norm,其超連線矩陣是一個 圖片的矩陣,右下三角部分填充為 1,其餘部分為佔位符 0。對於 Post-Norm,權重依賴於輸入和輸出的方差及協方差,形成一個 圖片的矩陣。因此,它們的超連線矩陣是不可訓練的。

而本工作提出的方法的超連線矩陣是 圖片矩陣,且權重是可訓練的,甚至可以基於輸入進行動態預測。

順序 - 並行二象性

給定一系列神經網路模組,我們可以將它們順序排列或並行排列。作者認為,超連線可以學習如何將這些層重新排列,形成順序和並行配置的混合。
圖片
在不失一般性的情況下,可以將擴充套件率設定為 n=2。如果超連線以如下矩陣形式學習,神經網路將被順序排列:
圖片
在這種情況下,深度連線退化為殘差連線,如圖 (a) 所示。

當奇數層和偶數層的超連線矩陣分別定義為以下形式時,神經網路每兩層將被並行排列,類似於 Transformer 中的 parallel transformer block 的排列方式,如圖 (b) 所示。
圖片
因此,透過學習不同形式的超連線矩陣,網路層的排列可以超越傳統的順序和並行配置,形成軟混合甚至動態排列。對於靜態超連線,網路中的層排列在訓練後保持固定;而對於動態超連線,排列可以根據每個輸入動態調整。

4. 實驗結果

實驗主要集中在大規模語言模型的預訓練上,涵蓋了 Dense 模型和 MoE 模型。

實驗結果表明,使用超連線的模型顯著優於使用殘差連線的模型。

1B Dense 模型實驗
圖片
只要擴充套件率 > 1,效果就十分顯著,且訓練更穩定,消掉了訓練 loss 的 spikes。

7B Dense 模型實驗

團隊甚至 Scale 到了 7B 模型,效果也十分亮眼,同時可以看到有超連線的網路訓練更穩定。
圖片
7B 候選啟用 1.3B 的 MoE 模型實驗
圖片
可以看到,下游指標全漲,在 ARC-Challenge 上甚至漲了 6 個百分點。
圖片
綜上,研究團隊介紹了超連線(Hyper-Connections),它解決了殘差連線在梯度消失和表示崩潰之間的權衡問題。實驗結果表明,超連線在大規模語言模型的預訓練以及視覺任務中都表現出顯著的效能提升。

值得注意的是,超連線的引入幾乎不增加額外的計算開銷或引數量,團隊認為,該成果具有廣泛的應用潛力,可以推廣到文音檢視模態的不同任務上,包括多模態理解、生成基座模型等。

5. 寫在最後

團隊關注底層問題,尤其在 LLMs 和多模態方面,期望實現更多突破。

更多團隊技術研究進展,可以進入「豆包大模型團隊」技術解讀欄目瞭解。

相關文章