線上和離線對齊演算法的效能差距根源何在?DeepMind實證剖析出爐
在 AI 對齊問題上,線上方法似乎總是優於離線方法,但為什麼會這樣呢?近日,Google DeepMind 一篇論文試圖透過基於假設驗證的實證研究給出解答。
論文標題:Understanding the performance gap between online and offline alignment algorithms
論文地址:https://arxiv.org/abs/2405.08448
根據人類反饋的強化學習(RLHF)隨著大型語言模型(LLM)發展而日漸成為一種用於 AI 對齊的常用框架。不過近段時間,直接偏好最佳化(DPO)等離線方法異軍突起 —— 無需主動式的線上互動,使用離線資料集就能直接對齊 LLM。這類方法的效率很高,也已經得到實證研究的證明。但這也引出了一個關鍵問題:
AI 對齊是否必需線上強化學習?
對於這個問題,人們希望既知道其理論上的答案,也希望明晰實驗給出的解答。
從實證角度看,相比於大家常用的線上 RLHF(由偏好建模和從模型取樣組成),離線演算法實現起來要簡單得多,成本也低得多。因此,收集有關離線演算法的充分性的證據可讓 AI 對齊變得更加簡單。另一方面,如果能明晰常用線上 RLHF 的優勢,也能讓我們理解線上互動的基本作用,洞見離線對齊方法的某些關鍵挑戰。
線上演算法與離線演算法的對比
要公平地比較線上和離線演算法並非易事,因為它們存在許多實現和演算法方面的差異。舉個例子,線上演算法所需的計算量往往大於離線演算法,因為它需要取樣和訓練另一個模型。因此,為了比較公平,需要在衡量效能時對不同演算法所耗費的預算進行一定的校準。
在 DeepMind 的這項研究中,研究團隊在比較時並未將計算量作為一個優先考慮因素,而是採用了 Gao et al. (2023) 的論文《Scaling laws for reward model overoptimization》中的設定:使用 RLHF 策略和參考 SFT 策略之間的 KL 散度作為預算的衡量指標。
在不同的演算法和超引數設定中,KL 散度是以一種統一的方式衡量 RLHF 策略與 SFT 策略的偏離程度,從而能以一種經過校準的方式對演算法進行比較。
基於古德哈特定律比較線上和離線演算法的效能
首先,該團隊比較了線上和離線演算法的過度最佳化(over-optimization)行為 —— 該行為可透過將古德哈特定律外推至 AI 對齊領域而預測得到。
簡單總結起來,古德哈特定律(Goodhart’s law)可以表述成:一項指標一旦變成了目標,它將不再是個好指標。
該團隊採用了與 Gao et al. (2023) 類似的設定,基於一組開源資料集進行了實驗,結果表明:在同等的最佳化預算(相對於 SFT 策略的 KL 散度)下,線上演算法的效能表現通常優於離線演算法。
圖 1 給出了線上和離線演算法在四個不同的開源資料集上表現出的 KL 散度與策略效能之間的權衡。圖中的每個資料點代表了在訓練過程中某個特定檢查點下,針對特定一組超引數的策略評估結果。
其中,對於線上演算法,超引數並未被大量調整,而是始終使用一組固定的超引數;對於離線演算法,則是將不同超引數的結果池化後得出。可以觀察到如下結果:
符合古德哈特定律的過度最佳化。不管是線上還是離線演算法,效能都會隨 KL 散度先升後降。後期下降的原因是過度最佳化效應,這符合古德哈特定律的預測。
線上演算法能比離線演算法更高效地使用 KL 散度預算。相比於離線演算法,線上演算法似乎通常能實現更好的權衡。具體而言,在 KL 散度度量的預算一樣時,線上演算法得到的效能通常優於離線演算法。在不同的 KL 散度層級上,線上演算法在所有任務上的峰值效能都高於離線演算法。其中,在 OpenAI 摘要和 Anthropic 輔助任務上的峰值效能差異顯著,在另兩個任務上的峰值差異較小。
總之,線上演算法完全勝過離線演算法,這也奠定了後續研究的基礎。
對於線上和離線演算法效能差異的假設
為了更好地理解線上和離線演算法效能差異的根源,該團隊透過假設驗證的形式進行了研究。
也就是說首先提出一些假設,然後驗證它是否正確。先來看看他們提出了怎樣的假設。
假設 1:資料覆蓋情況。線上演算法更優的原因是其覆蓋的資料比離線資料集更多樣化(即隨時間變化取樣自不同的學習器策略)。
假設 2:次優的離線資料集。離線演算法處於劣勢,因為其初始的偏好資料集是由一個次優的策略生成的。如果使用有更高絕對質量的響應訓練離線演算法,則效能會更好。
假設 3:分類能力更好,則效能更好。離線演算法通常是將策略作為分類器進行訓練。但是,作為分類器,它們可能並不如代理偏好模型那樣準確(因為對分類進行引數化的有效方式不同)。如果準確度提升,則其效能也會提升。
假設 4:非對比式損失函式。在這樣的效能差異中,有多大部分可歸因於對比式的損失函式,而不是離線的樣本?
假設 5:擴充套件策略就足夠了。要彌合線上和離線演算法之間的差距,只需提升策略大小就足夠了。
實驗和結果
實驗設定
為了驗證上述假設,該團隊進行了大量對照實驗。
所有實驗都使用 T5X 模型,並搭配了 T5X 資料和計算框架。為了較好地覆蓋 RLHF 問題,他們研究了四種任務:OpenAI 摘要、Anthropic 輔助、聊天競技場、Anthropic 無害性。
圖 2 給出了這些對照實驗的設定情況,其整體上基於 Gao et al., 2023。其中,綠框表示資料集,藍框表示學習到的偏好模型或策略。
圖 3 則給出了線上生成資料集的圖示。這裡的線上演算法主要由代理偏好模型和線上學習的策略之間的互動組成。
該團隊的實驗研究涉及多個維度,其得到的主要結果如下。
資料
該團隊提出的一些假設涉及到離線資料集的性質。其中包括假設離線資料集的覆蓋情況比線上生成的資料集差;假設離線演算法對離線資料集更敏感,而離線資料集中響應的絕對質量要差一些。(圖 4 和圖 5 分別證否了這兩個假設)。
儘管這些假設聽上去似乎是對的,但實驗結果表明它們無法可信地解釋線上和離線演算法的效能差距。
他們透過消融研究發現,提升離線最佳化的一種有效方法是生成分佈上接近起始 RLHF 策略(這裡就剛好是 SFT 策略)的資料,這本質上就模仿了線上演算法的起始階段。
最佳化性質
該團隊發現判別能力和生成能力之間存在一種有趣的相互作用:儘管離線策略的分類能力勝過線上策略,但離線策略生成的響應卻更差(見圖 6、7、8)。
不管是類間分類還是類內分類實驗,分類效能和生成效能之間的關聯似乎都不大。儘管離線和線上取樣都是針對一個判別目標最佳化的,但離線取樣是提升在一個靜態資料集上的分類準確度,而線上取樣則是透過不斷改變取樣分佈來提升生成質量。實驗表明,離線策略的生成效能提升不如線上策略的直接。
損失函式與擴充套件
為了確保所得結果更普適,他們還研究了用於 RLHF 的對比式和非對比式損失函式。
線上與離線效能之間的差距似乎總體上持續存在,儘管這種差異的根本原因可能與演算法有關。他們也研究了效能差距隨策略網路規模擴充套件的變化情況(見圖 10 和 11)。效能差距一直存在這一事實說明:只是擴充套件模型規模可能無法解決取樣問題。
儘管實驗結果暗示了在策略取樣對模型對齊的根本重要性,但這些結果也許有助於揭示離線對齊演算法的實驗內部工作原理,並揭示效能差異的根源。總而言之,這些發現為 RLHF 實踐者提供了有趣的見解和挑戰,併為更有效的 AI 對齊實踐鋪平了道路。
根據現有的強化學習研究成果,線上比離線更好似乎是顯而易見的結論。線上和離線強化學習演算法之間的效能差距也已經被多項研究發現,所以這項研究給出了什麼不一樣的結論呢?
最重要的是,線上 RLHF 演算法依賴於一個學習後的獎勵模型,該獎勵模型是使用與離線 RLHF 演算法一樣的成對偏好資料集訓練得到的。這與常規強化學習設定存在根本性差異 —— 常規強化學習假設能以線上方式獲取基本真值獎勵,在這種情況下,線上強化學習的優勢明顯。假設 RLHF 受到獎勵訊號的瓶頸限制,我們就不清楚線上與離線的差距是否還會這樣顯著。
從更技術性的角度來看,許多 RLHF 演算法採用了上下文賭博機的設計形式,並針對參考策略應用了正則化。這樣的演算法細節讓 RLHF 偏離了常規的強化學習設定,這可能會影響離策略學習問題的嚴重程度。