阿里雲通義大模型新技術:MoE模型訓練專家平衡的關鍵細節

机器之心發表於2025-01-24

圖片

AIxiv專欄是機器之心釋出學術、技術內容的欄目。過去數年,機器之心AIxiv專欄接收報導了2000多篇內容,覆蓋全球各大高校與企業的頂級實驗室,有效促進了學術交流與傳播。如果您有優秀的工作想要分享,歡迎投稿或者聯絡報導。投稿郵箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

本週,在阿里雲通義千問 Qwen 團隊提交的一篇論文中,研究人員發現了目前最熱門的 MoE(混合專家模型)訓練中存在的一個普遍關鍵問題,並提出一種全新的方法——透過輕量的通訊將區域性均衡放鬆為全域性均衡,使得 MoE 模型的效能和專家特異性都得到了顯著的提升。

圖片
  • 論文:《Demons in the Detail: On Implementing Load Balancing Loss for Training Specialized Mixture-of-Expert Models》

  • 論文連結:https://arxiv.org/abs/2501.11873

圖片
MoE 模型訓練中的關鍵問題

混合專家模型(MoEs)透過路由機制動態並稀疏地啟用模型引數,使得能高效地增大模型引數規模。基於 TopK 機制的稀疏啟用會在訓練中會遇到專家啟用不均衡的問題:少數被頻繁選擇的專家會被最佳化得更多,進一步使得這些專家被更頻繁地選擇,最終導致只選擇少數專家,造成剩餘專家的冗餘。因此,MoE 在訓練中需要引入額外輔助的負載均衡損失(load balance loss,LBL)來鼓勵專家的選擇趨於均衡。

目前主流 MoE 訓練框架中實現的 LBL 的最佳化目標是區域性(micro-batch)的負載均衡,這使得模型需要將一個micro-batch的輸入都均勻分配給不同的專家。然而,一個micro-batch的輸入往往只來自個別領域,區域性負載均衡會讓模型將每個領域的輸入都均勻分配。這種均勻分配會阻礙某些專家更多處理特定領域的資料,也即阻礙專家出現領域層次的分化特徵。我們發現,將區域性的負載均衡放鬆到全域性的負載均衡,能顯著增強專家的特異化並提高模型效能。

背景

混合專家(Mixture-of-Experts,MoE)是一種高效的在訓練時擴充套件模型引數規模的技術。通常,一個MoE層由一個路由器(通常是一個線性層)和一組專家組成(對於Transformer的模型,每個專家是一個前饋神經網路)。給定一個輸入,只有部分專家會被啟用,然後它們的輸出會根據路由器分配的權重進行聚合。具體來說:

圖片

負載均衡損失

負載均衡損失是訓練 MoE 網路中的一種重要正則化技術,其核心思想是鼓勵所有專家的均衡啟用。它可以透過以下公式計算:

圖片

其中, 是專家 的啟用頻率, 是分配給專家 的平均路由分數。

然而,大多數現有的MoE訓練框架(例如Megatron-core)實現的是區域性(micro-batch)層次的均衡,這意味著在每個 micro-batch 內計算 LBL ,然後在全域性(global-batch)層次上進行平均,即:

圖片

其中 為 micro-batch 數, 是在第 個 micro-batch 上計算的負載均衡損失, 為在第 個 micro-batch 上統計出的啟用頻率和路由分數。

我們關注的關鍵點是,如果一個 micro-batch 中的資料不夠多樣化,這種實現方式可能會阻礙專家的特異化。例如,假設一個 micro-batch 中只包含程式碼資料,上述負載均衡損失仍然會推動路由器將這些程式碼輸入均勻分配給所有專家。而理想狀況下,處理程式碼資料的專家網路應該對程式碼資料有更高的啟用頻率。在訓練基於 MoE 的大型語言模型時,這種情況更常見:一個較小的 micro-batch (通常為 1)中的資料通常來自同一領域。這在一定程度上解釋了為什麼當前大多數基於 MoE 的大語言模型中都沒有觀察到明顯的領域層次的專家特異化。

這一缺點促使我們將當前區域性均衡的方法想辦法擴充套件到全域性(global-batch)均衡。

從區域性均衡到全域性均衡

得得益於 LBL 計算的格式,我們可以透過通訊不同節點的 來將區域性 轉化為全域性的 :1)在所有 micro-batch 之間同步專家選擇頻率 ;2)在每個GPU上計算負載均衡損失;3)在所有 micro-batch 之間聚合損失。具體來說:

圖片

其中 是對全域性統計的啟用頻率和門控分數,第一個等式為 的計算方式,第二個等式為全域性路由分數可以由區域性路由分數平均而來,第三個等式表示用全域性啟用頻率參與區域性計算後再平均聚合等價於全域性均衡損失。因為 只是一個專家數大小的向量,即使是在全域性通訊的情況下也不會帶來明顯的開銷。此外由於 LBL 的計算與模型其它部分的計算相對獨立,還可以用計算掩蓋等策略進一步消除同步 的通訊開銷。

此外,對於需要梯度積累的情景,我們還提出了快取機制來累積各個積累步統計的專家啟用頻率,使得計算節點較少、只進行一次通訊達到的均衡範圍有限的情況下,也能逐漸近似全域性統計的啟用頻率。

擴大均衡的範圍帶來穩定的提升

我們在三種引數規模(3.4B 啟用 0.6B, 15B 啟用 2.54B,43B 啟用 6.6B)下分別訓練了 120B 和 400B tokens,對比了不同的均衡範圍(Balance BSZ)對模型效能的影響。所有模型都使用了細粒度專家、共享專家及 dropless 策略(專家不會拋棄超過容量的tokens)。可以看到,將均衡範圍從一般框架實現的 4,8 或者 16 增大到 128 以上後模型在 Benchmark 指標和 PPL 都有明顯提升。
圖片
我們在 3.4B 啟用 0.6B 的模型訓練 400B tokens 到設定上進一步對比了模型效果隨著均衡範圍的變化,可以看到 balance BSZ 從 2 到 128 模型的 PPL 在快速降低,在 128 後逐漸飽和。目前主流 MoE 框架中即使是進行了機內通訊,對於較大的模型 balance BSZ 也一般在 8 到 16 的,這進一步體現了我們通訊方法的意義。
圖片

分析實驗

假設驗證

前文提到,這篇工作的出發點是在一個 micro-batch 中,資料的來源較為單一的,進而導致 MoE 模型需要將類似來源的資料均勻分配到所有expert上,我們改進了這一點進而得到了提升。

然而,我們也可以假設 global batch 是因為使用了更多的 token 來統計 expert 啟用頻率進而減少了方差,使得負載均衡損失更加穩定,進而提升訓練洗哦啊過。位了更加嚴謹地對比這兩種假設,我們引入了一種對比的實驗設定:Shffuled batch balance, 即我們從global batch中隨機抽取一個子集(這個子集的大小等於micro batch的大小)統計專家啟用頻率,進而計算負載均衡損失。Shuffled batch balance 和 micro-batch balance擁有相同的token數目,和 global-batch balance擁有相同的token分佈。
圖片
我們發現,shuffled batch balance 和 global batch balance 的表現幾乎一致,都顯著好於 micro batch balance。說明,引入 global-batch 獲得提升的首要原因是在一個更加通用、多樣的 token 集合上計算損失。進而驗證了我們的出發點和假設。

新增少量區域性均衡損失

能提高模型效率

只使用全域性均衡會導致區域性均衡狀況有所降低,這會一定程度影響 MoE 的計算效率。我們進一步實驗了在主要使用全域性均衡的情況下,在訓練過程中新增區域性均衡(預設實現的 LBL,損失權重為全域性 LBL 的 1%)限制對於模型效能和效率的影響。可以看到,新增區域性均衡能提升模型的速度(每個更新步耗時從 1.64秒提升到1.59秒),同時模型的效果也幾乎不受影響。
圖片

同期相關工作以及討論

已有工作 GRIN 也提出了 Global Load Balance Loss Adaptations,然而更多將這一均衡方法作為訓練框架只使用張量並行、不使用專家並行的優勢。GRIN 中並沒有從 specialization 或是對模型 performance 影響等方面討論使用 Global Load Balance 的動機,也沒有展示單一使用 Global Load Balance 的影響。

Wang et al. 提出在基於MoE的大語言模型訓練中,負載均衡損失和語言模型損失如同槓桿一樣需要權衡,因為兩者的最佳化目標並不一致。因此,他們提出了一種基於專家選擇頻率更新的偏差項(bais term),在不改變路由分數的情況下平衡專家選擇,從而去掉了用來輔助訓練的負載均衡損失(auxiliary-loss free)。基於專家選擇頻率更新的偏置項,以在不改變路由評分的情況下平衡專家選擇。但是,他們沒有比較該方法在專家選擇頻率是根據 micro-batch 計算和根據 global-batch 計算時的效能差異。

這項工作也被應用到 deepseek-v3 的訓練中。deepseek-v3 的技術報告(同期工作)中強調了這項技術的專家選擇頻率是基於 global-batch 進行計算,並在小規模上討論了基於global batch 使用 LBL 的結果,也發現這兩種方法結果相似。

而我們的工作不僅在大規模上系統驗證了這種方法的有效性,還詳細析了均衡範圍對效能的影響,並消融證明了 global-batch 是透過納入更多樣化的領域資訊從而顯著提效能。

結論

我們回顧了目前 MoE 訓練框架中均衡損失,發現目前的實現方式會將所有來自相同領域的區域性輸入都均勻分配,限制了專家的分化。透過輕量的通訊將區域性均衡放鬆為全域性均衡,MoE 模型的效能和專家特異性都得到了顯著的提升。我們認為這一進展解決了現有MoE訓練中的一個關鍵問題,為MoE模型的最佳化提供了新的視角,並有助於構建更加可解釋的模型。儘管我們的實驗主要集中在基於語言的任務上,我們希望我們的工作能夠為在不同領域訓練更大規模、更有效的 MoE 模型提供幫助。

相關文章