最近大家都在探討和嘗試復現OpenAI O1的思考效果,解碼出的關鍵技術方向,包括之前已經探討過的Inference Time Scaling在推理過程中進行路徑決策和選擇。但想要更優的Inference Time Scaling曲線,前提是模型本身是一個很強的Generator,已經擁有足夠的生成合理推理過程的能力,同時還擁有很強的Verifier模型來對推理節點進行打分決策,並且二者可以在少人類監督的條件下不斷迭代最佳化。
這一章我們先聊聊如何讓大模型"自學"推理思考,從而得到思考推理能力更強的Generator。本章會以STaR論文為基礎,介紹生成複雜動態思維鏈背後可能的技術方案
STaR
- STaR: Self-Taught Reasoner Bootstrapping ReasoningWith Reasoning
STaR是這一系列論文的第一篇,思路就是妥妥的Bootstrap,生成推理過程->訓練模型->生成更優的推理過程->訓練更強的模型。
STaR的流程很直觀
- Pretrain模型,透過指令+fewshot,引導模型對QA資料集生成推理過程
- 對以上推理過程進行過濾,只保留回答正確的
- 對推理答案錯誤的,透過Hint(在上文中告訴模型正確答案),引導模型生成正確的推理過程,對這部分樣本也進行過濾,只保留回答正確的
- 使用以上樣本進行SFT,教模型如何思考
- 再使用SFT後的模型重複以上樣本生成的過程,直到評估指標不再提升
STaR的優缺點都非常明顯,優點就是不需要大量人工標註的思維鏈樣本,也不依賴更強大的模型提供合成樣本(其他模型提供的合成樣本本身也可能存在分佈漂移會影響模型效果),實現了一定程度的模型自我最佳化提升。缺點有
- 可用場景有限:STaR依賴正確答案作為過濾條件,因此只適用於問答,數學計算等有限領域,對於更廣泛的開放領域無法適用。這個限制其實也是因為STaR並未引入Verifier,因此只能依賴答案本身作為評估基準。
- SFT本身的泛化性有限:透過SFT把生成的推理過程注入模型,很難讓模型學到推理過程中的獎勵訊號,更多還是在做Behaviour Cloning。達不到"Don't Teach, Incentive"的效果
- STaR對樣本的使用率不足,只使用了唯一的一條正確樣本,丟棄了通往正確答案的更多正確路徑,也丟棄了更大量級的錯誤思考過程
- 思考鏈路是靜態,既針對任何問題模型都預設上來就進行思考,這種形式在單一場景中適用,在更靈活廣泛的實際場景中思考應該動態存在
下面我們看下針對以上問題,其他論文給出了哪些最佳化方案,以下論文更多會關注和STaR的對比~
RFT
- Scaling relationship on learning mathematical reasoning with large language models
RFT也是模型自我合成資料進行最佳化的方案,它沒有使用STaR的多輪Bootstrap來持續最佳化合成資料,只用了一輪最佳化,但RFT給出了在一輪迭代內,更充分利用正樣本的方案。
RFT會使用SFT後的模型,針對每個問題隨機取樣100條推理路徑,篩選所有答案正確的推理路徑,並使用編輯距離對不同的推理路徑進行消重,只保留差異化的正確推理路徑。這樣對比以上STaR每個問題只有1條正確樣本,RFT對每個問題會保留多樣性的正確推理路徑,然後使用該合成資料集對模型進行訓練。對比後發現使用更多推理路徑效果會有提升,同時去重也會帶來明顯的效果提升。大機率因為不去重,會導致部分重複樣本的過度擬合,影響泛化性。
RFT這種使用模型自我合成資料再微調基座的方案,在後面Google Deepmind的論文中也進一步論證了它的有效性要超過使用更強大的模型直接合成資料的效果。部分因為多個正確推理路徑的提供,能給模型提供一些哪些推理節點是核心節點的有效資訊,降低模型模仿率,提高模型泛化性。
V-STaR
- V-STaR: Training Verifiers for Self-Taught Reasoners
V-STaR沿用了STaR的多輪Bootstrap樣本迭代的方案,並給出了一種簡單的利用負樣本的方案,在以上STaR的基礎上,每一輪模型生成推理答案時,正確和錯誤的推理鏈路都會被保留,其中正確的樣本用來訓練微調Generator,而正確和錯誤的樣本會合並用於訓練Verifier。
以及和STaR每一輪都只使用新訓練的Generator合成的樣本不同,這裡訓練Verifier的樣本是每一輪收集樣本的並集。因為RM模型需要廣泛學習不同分佈的推理結果,而每一輪隨著Generator不斷增強,其實都在拓寬RM模型學習的樣本範圍,提升Verifier的泛化性。
最後論文用收集好的正負樣本,構建了針對問題的對比樣本對(x, y+,y-) ,然後使用DPO在最後一輪微調得到的最優的Generator上來訓練Verifier。並在推理過程中使用該Verifier,來實現best-of-n策略,從N個隨機取樣的推理結果中選擇RM得分最高的推理鏈路。
效果上加入Verifier的STaR效果會有進一步提升,並且多輪Bootstrap也能有效提高V-STaR的效果。
Incorrect Synthetic Data
- RL on Incorrect Synthetic Data Scales the Efficiency of LLM Math Reasoning by Eight-Fold
GDM這篇論文對正負合成思維鏈樣本都做了更加全面的討論,基本結論如下
- 正樣本:論文論證了前面RFT,也就是使用微調模型自我生成推理鏈路的方案,要優於使用更強模型直接生成樣本進行SFT。但是隻使用合成正樣本做SFT,因為無法保證鏈路的完全正確,會讓模型學到一些混淆的錯誤思考模式。
- 負樣本:對比V-STaR只在Verifier中簡單利用了負樣本,論文給出了在最佳化Generator中使用負樣本的訓練方案
下面我們分正負樣本來分別說下~
正樣本:為何自我生成的正樣本效果更好?
論文分別採用兩種方案來合成資料
- SFT:使用更強大的模型合成資料,例如GPT4來生成帶有思維鏈的推理樣本,經過簡單的消重,過濾錯誤答案後,使用正確樣本直接微調模型
- RFT:模型自我合成資料,使用以上微調後的模型,針對每個問題再生成N個推理結果,經過過濾後使用正確的樣本微調模型,也就是使用基座微調模型自我生成的樣本再回來微調基座
論文發現在Deepseek和Llama2上,隨著合成資料集的數量變大,RFT顯著優於SFT,並且優勢並不隨資料集變大而縮小。具體到資料使用效率,相同的Test Error下,使用RFT策略訓練的效果相當於使用2倍的合成資料進行SFT
這個結論會有一些反直覺,因為之前很多最佳化小模型的思路都是去蒸餾GPT4的回答。當然後面也有一些研究認為擬合另一個模型的回答,因為預訓練的差異,導致微調過程中模型很難直接學習新的推理回答只能強行記憶,影響模型泛化效果。 類似的問題其實在早期我們也用GPT3.5,GPT4的回答去構建樣本,然後微調一些小模型的時候就發現了,當回答風格差異巨大的時候,直接微調,會影響基座本身的知識儲存和指令理解。其實就是小模型為了去強行改變自己的輸出風格,負向影響了模型本身的引數分佈。
論文使用RFT生成的樣本,相比SFT樣本,在基座模型上有更高的log likelihood來論證之所以使用RFT的樣本微調效果更好,就是因為RFT樣本是基座模型自我合成的,因此和基座模型本身的推理分佈更加接近,模型更好學習,會降低模型去強行記憶的機率,對泛化性的損失更小,更加“easy-to-fit”。
但不論是SFT還是RFT,論文提出都需要關注正確樣本中錯誤的推理鏈路,因為樣本過濾只使用了答案,並未對中間推理鏈路的正確性進行校驗,而這些錯誤的步驟,會導致模型學到一些混淆的因果關係。而虛假步驟帶來的推理問題,並無法透過簡單的增加合成資料的方法來解決。
下面我們接著看論文如何透過引入負樣本和per-step DPO來最佳化合成樣本中錯誤步驟帶來的問題。
負樣本:呦呵你沒想到我也這麼有用吧
既然同一個問題生成多條正向的推理鏈路的合成樣本可以提升效果,那如何更有效的利用比正樣本佔比更高的負樣本呢?前面V-STaR是選擇利用負樣本去訓練Verifier,而GDM的論文給出了透過正負樣本對比學習來充分利用負樣本的方案。論文設計的RL目標函式如下,透過正負樣本分別和基準(微調後的基座模型)模型對比,來進行對齊。
並且論文給出了從“關鍵步驟”這個概念出發構建正負樣本對的方案,那啥叫關鍵步驟嘞?
可以從熵值的視角去看,如果生成步驟A後,模型得到正確答案,或者錯誤答案的機率顯著更高,那步驟A就是關鍵步驟。其中通往錯誤的核心步驟需要模型遺忘,透過正確的核心步驟需要學習。
那如果生成步驟A後,模型得到正確和錯誤答案的機率一半一半,那步驟A就不是關鍵步驟。想要獲得每個步驟通往正確、錯誤答案的機率,其實只需要透過蒙特卡洛模擬取樣足夠多的鏈路,然後做個機率統計就行嘞
以上的關鍵價值,論文用以下的公式來計算,每個步驟(i)的價值(Q value),都是給定(1i-1)的步驟,計算步驟模型在未來(i+1L)步內獲得正確答案的期望機率。以上價值其實是步驟(1~i)的累計價值,而每個步驟的增量價值,就是和截止前一步Q value的差值。
所以構建正負推理鏈路的步驟,就是基於每個問題,使用微調後的基準模型取樣N條推理鏈路,基於答案是否正確,先初步分成正負樣本。然後在負樣本內,尋找每個推理鏈路中關鍵價值最低的第一個步驟(大機率導致模型失敗的第一個核心節點)作為\(-y_{1:c}\)。在正樣本內,尋找每個推理鏈路中關鍵價值的第一個步驟作為\(+y_{1:c}\),然後基於這些正負樣本對,使用以上對比RL損失函式進行訓練。
效果上對比SFT,per-step DPO訓練的效果,會發現相同Test-Error下,使用per-step DPO訓練,相當於使用8倍的合成資料進行SFT, 也就是per-step DPO的合成資料使用效率比SFT要高8倍 對比前面RFT,利用負樣本,識別關鍵步驟進行對比學習,對合成資料有更高的利用率~
Quiet-STaR
- Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking
- THINK BEFORE YOU SPEAK: TRAINING LANGUAGE MODELS WITH PAUSE TOKENS
- https://github.com/ezelikman/quiet-star?tab=readme-ov-file
Quiet-STaR擴充了以上STaR們的應用範圍,針對開放世界的各類問題,給出了讓模型先給出多條思考路徑,再在推理時動態決策多大程度上依賴前置思考來回答問題的新方案。
Quiet-STaR的實現包括3步,我們對照的著下面的程式碼來說
- N路平行思考生成:針對每個輸入token,生成N個推理鏈路,每個推理鏈路的長度都是固定的,並且在推理鏈路的開始和結束,分別插入新的Special token <startofthought>,<endofthought>。 這些推理鏈路就是模型的隱藏思考步驟。對應以下程式碼中continuation迭代生成的部分。
- 混合權重頭:論文在原生LLM的基礎上新增了一個分類頭,就是簡單一層或多層MLP。分類頭用於進行權重混合,也就是在預測真實的下一個token的時候,根據分類頭的權重,決定多少資訊來自上一個輸入token,多少資訊來自新生成的隱藏思考步驟。對應以下程式碼中的mixing_weight和mixed_hidden_states。
- RL訓練最佳化思考生成:最後透過強化學習,在訓練以上<startofthought>,<endofthought>,MLP分類層的基礎上,讓模型生成的隱藏思考步驟更加有用,可以提高模型推理效果。
# Append the start thought token to the input sequence
start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
seq_len += 1
# Update the attention mask
if attention_mask is not None:
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
# Generate the continuation
continuation_length = self.n_ahead - 2
new_key_values = past_key_values
start_time = time.time()
for continuation_idx in range(continuation_length):
outputs = self.model(
input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=new_key_values,
inputs_embeds=inputs_embeds,
use_cache=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
new_key_values = outputs.past_key_values
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits[:, -1, :] # Only consider the last token
# Apply Gumbel-Softmax to the logits
next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
next_token_id = torch.argmax(next_token_logits, dim=-1)
# Append the generated token to the input sequence
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
seq_len += 1
# Update the attention mask
if attention_mask is not None:
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
# Append the end thought token to the input sequence
end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
seq_len += 1
# Update the attention mask
if attention_mask is not None:
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
# Get the hidden states before and after the thought
outputs_before = self.model(
input_ids=original_input_ids,
attention_mask=original_attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states_before = outputs_before[0][:, -1:, :]
# two new tokens: last continuation token and end thought token
outputs_after = self.model(
input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1),
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=new_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states_after = outputs_after[0][:, -1:, :]
# Apply the talk head to get the mixing weight
mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
# Apply the mixing weight to the hidden states
mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
# Apply the language model head to get the final logits
logits = self.lm_head(mixed_hidden_states)
按照以上的步驟,我們來舉個例子,輸入是“今天天氣“:
1. 平行思考生成(N=5)
- 思考1:<startofthought>是否下雨、下雪等異常天氣<endofthought>
- 思考2:<startofthought>溫度、溼度情況<endofthought>
- 思考3:<startofthought>溼度情況<endofthought>
- 思考4:<startofthought>對人們生活活動的影響<endofthought>
- 思考5:<startofthought>穿衣建議<endofthought>
2. 權重混合(思考1舉例)
- 無思考時輸入的next-token:很(0.4)+ 真(0.6)
- 思考1後面的next-token:晴朗(0.2) + 下雨(0.2) + ....各種天氣情況
- 思考1混合後(混合權重0.5):很(0.2) + 真(0.3)+晴朗(0.1) +下雨(0.1)+ ....各種天氣情況*0.5
3. RL訓練(如果原始文字後面是“晴朗”)
- RL給予“思考1”以正向梯度更新,在天氣後面思考溫度有助於模型推理
論文給出的一個實際的推理效果如下
再來說下訓練的部分,論文采用了強化學習來訓練新加入的MLP混合頭,思考開始和結束的token,以及基座模型本身。因為前面再每個位置生成了N個不同的思考路徑,論文選擇的RL訓練目標是最佳化每個token位置生成正確思考路徑,相對其他思考路徑對推理下一個token帶來的增量資訊。
然後鼓勵模型去生成對推理下一個token有幫助作用的思考路徑(Tj)
同時還會增加一個NLI Loss,來訓練用於混合思考和推理權重的MLP頭。訓練對比前面STaR等策略都是使用QA等特定領域指令樣本,Quiet-STaR選擇OpenWebMath(技術網站爬蟲資料)進行訓練,考慮技術類文字依賴思考的情況更多,模型訓練得到的正向訊號會更加密集。因為訓練本身是預訓練的Next-Token-Prediction,因此對比STaR具有更好的泛化效果,可以不限領域,不限任務進行訓練。
Quiet-STaR還有待進一步最佳化的問題包括
- 動態思考位置的選擇:Quiet-STaR是在每個位置都生成N個思考鏈路後,再使用mix-head來對每個位置的思維鏈和原始推理進行權重融合,屬於後選擇方案,推理成本較高,如果能根據輸入本身進行前置的思考位置選擇,只在最優的一個或幾個位置上進行內生思考推理(MCTS)就更完美了
- 模型內容思考可能本身不可解釋,因為Quiet-STaR只在HighLevel層面去最佳化加入內生思考後,模型推理效果的提升,並未對思考本身的next-token prediction進行對齊,導致生成的思考本身甚至可能並不在語言上通順。當然因為本身是在訓練後的基座模型上推理,所以肯定保留了部分的語言邏輯性
- 模型內生思考可能存在各種3H(helpful,harmless,honesty)問題。同樣是對齊問題,模型生成的思考鏈路不僅未在語言模型角度對齊,也未在人類偏好角度對齊,這可能也是OpenAI在O1中考慮對使用者隱藏內在思考鏈路的原因之一。而對齊本身是否會影響內生思考的效果需要額外的實驗驗證。
Quiet-STaR和OpenAI O1在生成模型內生思考上的技術棧是很像的。OpenAI在O1的使用說明Link中也指出,O1是透過動態插入思考token,來生成內生思考,並基於內生思考進行推理回答,思考對使用者不可見(OpenAI在Learning to Reason with LLMs中也說明隱藏思維鏈的部分是未對齊的),只展示回答部分。而多輪對話的上文也只會使用輸入輸出不會使用內生回答。使用感受上在金融場景下,一些強數字,強邏輯的問題例如表格問答,財務問題分析上O1有比較顯著的效果提升。
想看更全的大模型論文·微調預訓練資料·開源框架·AIGC應用 >> DecryPrompt
OpenAI O1技術路線解析的一些好文推薦~
- OpenAI Learning to Reason with LLMs
- 北大對齊團隊獨家解讀:OpenAI o1開啟「後訓練」時代強化學習新正規化
- Reverse engineering OpenAI’s o1
- OpenAI’s Strawberry, LM self-talk, inference scaling laws, and spending more on inference
- OpenAI o1 self-play RL 技術路線推演
- 讓 LLM 下一盤大棋:RL 正規化探討