Meta開發System 2蒸餾技術,Llama 2對話模型任務準確率接近100%

机器之心發表於2024-07-15
研究者表示,如果 Sytem 2 蒸餾可以成為未來持續學習 AI 系統的重要特徵,則可以進一步提升 System 2 表現不那麼好的推理任務的效能。

談到大語言模型(LLM)的策略,一般來說有兩種,一種是即時的 System 1(快速反應),另一種是 System 2(慢速思考)。

其中 System 2 推理傾向於深思熟慮的思維,生成中間思維允許模型(或人類)進行推理和規劃,以便成功完成任務或響應指令。在 System 2 推理中,需要付出努力的心理活動,尤其是在 System 1(更自動化思維)可能出錯的情況下。

因此,System 1 被定義為 Transformer 的應用,可以根據輸入直接生成響應,而無需生成中間 token。Sytem 2 被定義為生成中間 token 的任何方法,包括執行搜尋或多次提示然後最終生成響應的方法。

業界已經提出了一系列相關的 System 2 技術,包括思維鏈、思維樹、思維圖、分支解決合併、System 2 Attention、Rephrase and Respond (RaR) 等。得益於這種明確的推理,許多方法都顯示出更準確的結果, 但這樣做通常會帶來更高的推理成本和響應延遲。因此,許多此類方法未在生產系統中使用,而大多使用了 System 1。

對於人類來說, 學習將技能從深思熟慮(System 2)轉移到自動(System 1)的過程在心理學中被稱為自動性,以及程式記憶的使用。例如,第一次開車上班時,人們通常會花費有意識的努力來計劃和做出到達目的地的決定。而在駕駛員重複這條路線後,駕駛過程就會「編譯」到潛意識中。同樣,網球等運動可以成為「第二天性」。

在本文中,來自 Meta FAIR 的研究者探索了一種類似的 AI 模型方法。該方法在給定一組未標記示例的情況下以無監督的方式執行編譯,被稱為 System 2 蒸餾。對於每個示例,他們應用給定的 System 2 方法,然後以無監督的方式測量預測的質量。

例如對於具有唯一答案的任務,研究者應用自洽性(self-consistency)並多次進行取樣。對於 System 2 足夠一致的示例,他們假設應該蒸餾此結果,並將其新增到蒸餾池中。然後對 System 1 進行微調,以匹配 System 2 方法對收集的示例池的預測,但不生成中間步驟。下圖 1 說明了將 System 2 蒸餾到 System 1 的整體過程。

圖片

研究者對 4 種不同的 System 2 LLM 方法和 5 種不同的任務進行了實驗。結果發現,本文方法可以在各種設定中將 System 2 推理蒸餾回 System 1 中,有時甚至比 System 2 教師的結果更好。此外,這些預測現在只需花費計算成本的一小部分即可產生。

例如,他們發現成功的蒸餾適用於處理有偏見的意見或不相關資訊的任務(System 2 Attention)、澄清和改進某些推理任務中的響應(RaR)以及 LLM 的細粒度評估(分支 - 解決 - 合併)。

不過,並非所有的任務都可以蒸餾到 System 1 中,尤其是需要思維鏈的複雜數學推理任務。這也反映在人類身上,如果沒有深思熟慮的 System 2 推理,人類就無法執行某些任務。

圖片

論文地址:https://arxiv.org/pdf/2407.06023v2

將 System 2 蒸餾回 System 1

設定:System 1 和 System 2 模型

給定一個輸入 x,研究者考慮設定一個單一模型,在他們的例子中是一個大語言模型 (LLM),它能夠實現兩種響應模式:

  • System 1:直接生成輸出 y。這類方法透過轉發(forwarding)底層自迴歸神經網路 (Transformer) 的各層來生成輸出標記來完成。

  • System 2。這類方法使用底層 Transformer 在生成最終響應 token 之前生成任何型別的中間輸出標記 z,可能包括多次呼叫(提示)。

從形式上,研究者將 System 2 模型 S_II 視為一個函式,它接受 LLM p_θ 和輸入 x,並且可以重複呼叫 LLM 以使用特定演算法生成中間標記 z,然後返回輸出 y:

圖片

System 2 方法可能涉及多個提示、分支、迭代和搜尋,同時使用 LLM 生成中間結果以供進一步處理。相比之下,System 1 模型僅考慮原始輸入 x 並直接呼叫 LLM pθ 來生成輸出 y:

圖片

方法:System 2 蒸餾

本文方法的第一步是使用 System 2 模型對未標記的輸入 X 生成響應:

圖片

然後,這些響應 y^i_S_II 可直接用作 System 2 蒸餾目標,以微調 System 1 模型。但是,它們容易受到噪聲的影響:其中一些響應可能是高質量的,而另一些可能是低質量或不正確的。對於涉及簡短響應(通常具有唯一正確(但未知)的答案)的簡短問答和推理任務,研究者考慮採用無監督管理步驟來嘗試提高訓練資料質量。他們考慮了以下兩種依賴於自洽性標準的變體:

  • 輸出的自洽性:對 S_II (x^i ; p_θ) 進行總共 N 次取樣,並接受多數投票響應;如果沒有多數投票獲勝者,則丟棄該示例。

  • 輸入擾動下的自洽性:以輸出不變的方式擾動輸入 x^i,例如改變提示中多項選擇題的順序,並計算每次擾動的 S_II;如果輸出不一致,則丟棄該示例。

之後研究者得到了合成資料集 (X_S_II , Y_S_II),其中 X_S_II 是 X 的一個過濾子集,目標是 Y_S_II。最後一步是使用這個蒸餾出來的訓練集對引數為 p_θ 的 LLM 進行監督微調。研究者通常從當前狀態 p_θ 初始化此模型,然後繼續使用新資料集進行訓練。微調後,他們得到一個 LLM 圖片,這是一個 System 1 模型,預計可提供與評估的 System 2 模型類似的輸出和效能提升。

實驗結果

訓練和評估設定

研究者使用 Llama-2-70B-chat 作為所有實驗的基礎模型。他們需要一個具有足夠能力的基礎模型,使其能夠像 System 2 模型一樣高效執行,同時還具有可以微調的開放權重,因此做出了此選擇。

同時,研究者考慮了幾種 System 2 方法,包括 System 2 Attention、 RaR、分支解決合併(Branch-Solve-Merge)和思維鏈, 並重點關注每種方法都顯示出強大效能的任務。

對於 System 1,研究者使用指令調整後的基礎模型作為標準基線進行零樣本推理。他們報告每個任務的任務特定指標,以及「#Tokens」指標,後者衡量評估集上每個輸入生成的平均 token 數量。System 2 方法則包括中間 token 生成以及最終輸出 token 生成。

Rephrase and Respond 蒸餾

RaR 是一種 System 2 方法,它首先提示語言模型以進一步闡述的方式來複述原始問題,然後基於複述的問題生成響應,目的是提供更優的輸出。

對於蒸餾資料,研究者使用輸出的自洽性為 RaR 構建 System 2 蒸餾資料集。對於每個輸入,他們對最後一個字母( last letter)任務進行了八次取樣迭代,並同樣對硬幣翻轉(coin flip)任務的每個階段進行八次取樣迭代,然後用多數投票來確定最終輸出。

首先來看最後一個字母連線(Last letter Concatenation)任務。此任務側重於符號推理,要求模型連線給定單詞的最後一個字母。整體結果如下表 1 所示。

基線 System 1 模型 (Llama-2-70B-chat) 的準確率達到 30.0%,低於 System 2 的 1-Step 和 2-Step RaR 方法(分別為 39.5% 和 44.5%)。透過本文無監督技術將 2-Step RaR 方法蒸餾回 System 1 Llama-2-70B-chat 模型,則實現了 98.0% 的驚人準確率

與零樣本聊天模型相比,模型可以有效地從這些訓練資料中學習如何解決任務。RaR 的蒸餾有效地繼承了 System 2 和 System 1 的優勢,既保留了 System 2 的準確率優勢,而其推理成本與 System 1 相當。

圖片

再來看硬幣翻轉推理任務。這種符號推理任務經常在研究中進行測試,它涉及確定硬幣的最終面(正面或反面),從已知的初始位置開始,經過一系列用自然語言描述的翻轉,例如「硬幣正面朝上」。

整體結果見上表 1。Llama-2-70B-chat(零樣板)在此任務上的成功率為 56.1%,而 1-Step 和 2-Step RaR 的成功率分別為 58.5% 和 77.2%。因此,使用 2-Step 方法獲得了巨大改進。透過本文無監督技術將 2-Step RaR 蒸餾回 System 1 Llama-2-70B-chat 可以獲得 75.69% 的結果。

因此,蒸餾的 System 2 模型提供的效能與 System 2(2 Step RaR)相當,但不需要使用 2 個提示執行 LLM 程式。

System 2 Attention 蒸餾

Weston 和 Sukhbaatar (2023) 提出了 System 2 Attention (S2A),這種方法有助於減少模型的推理陷阱,例如依賴輸入中的偏見資訊或關注不相關的上下文。

研究者驗證了將 S2A 提煉到 System 1 中的可行性,特別是 SycophancyEval 問答任務,該任務包含已知會損害 LLM 效能的輸入中的偏見資訊。

結果如下表 2 所示,報告了 3 個隨機種子的平均準確率。正如預期,基線(System1)LLM 在有偏見部分的準確率較低,容易受到有偏見輸入的影響。S2A 顯著提高了有偏見輸入的效能。System 2 蒸餾表現出與 System 2 方法類似的強大效能。

圖片

更多實驗結果請參閱原論文。

相關文章