編輯 | ScienceAI
OpenAI o1、DeepSeek R1 等模型成功實現了在數學、程式設計等領域的智慧慢思考推理,透過自我反思和修正實現了執行時的效能外推。
然而,在醫療領域,仍然很少有模型可以實現具有長鏈慢思考的推理。目前醫療領域的推理模型大多是透過在醫療考試題上對 OpenAI 系列的模型進行蒸餾,並沒有考慮推理過程的可驗證性,以及醫療任務的覆蓋度。
為了解決這些問題,上海交通大學人工智慧學院、復旦大學和上海人工智慧實驗室的團隊開發了一種新型醫學推理系統——MedS3。
該系統採用自我進化的「慢思考」正規化,無需預訓練和模型蒸餾,能夠對推理流程的每一步進行細粒度驗證。
論文連結:https://arxiv.org/pdf/2501.12051
專案主頁:https://pixas.github.io/MedS3-pages/
論文標題:MedS3: Towards Medical Small Language Models with Self-Evolved Slow Thinking
MedS3 由策略模型(Policy Model)和過程獎勵模型(Process Reward Model; PRM)組成,透過在 16 種不同資料集上的學習,包括醫療診斷、生物醫學和知識性問答等。
僅使用 7465 條種子資料,結合細粒度的蒙特卡洛樹搜尋和規則驗證的過程監督訊號,MedS3 迭代最佳化策略模型和過程獎勵模型。
評估結果顯示,MedS3 在醫療知識問答、生物醫學問答、長上下文問答和醫療診斷任務上的推理能力顯著超越現有醫療大模型和通用域推理模型,成為首個在醫療診斷任務上實現長鏈推理「R1」的大語言模型框架。
研究動機
以往的醫療模型訓練面臨醫療語料匱乏的問題,通常有兩種解決方案:
(1)在大規模人工收集篩選的醫療語料上進行預訓練;
(2)在少量特定任務資料集上進行有監督微調。然而,第一種方法消耗大量計算資源,但下游任務效能提升有限;第二種方法雖計算高效,但微調資料多為閉源模型生成的蒸餾資料或人工標註的短回覆資料,限制了模型的最佳化空間和跨任務泛化能力。
系統框架
為了解決醫療模型的資料困境,MedS3 轉向執行時縮放(test-time scaling),以一種資料高效的後訓練方法進行提升,從而突破資料集標註的約束,在平衡計算資源與效能之間的矛盾下,高效利用現有的醫療資料。
MedS3 的核心在於其獨特的自我進化框架。研究者首先利用蒙特卡洛樹搜尋(MCTS)技術,基於基礎策略模型生成可驗證的推理鏈。在推理鏈的每一步,都會基於這一步的正確性賦予一個展開值,透過這些經過驗證的軌跡來訓練策略模型和過程獎勵模型(PRM)。
這種搜尋對計算資源的依賴極小,透過策略模型演化得到的正負樣本均可以作為 MedS3 的監督訊號,大大增加了資料利用率,並且按步取樣也能提升模型的探索空間。
在推理過程中,策略模型會生成多個回答,而獎勵模型則透過新提出的 PRM 引導的投票求和(P-VS)策略來選擇最終答案。
這種策略不僅考慮了 PRM 對每個回覆的評判結果,也考慮了不同回覆之間的語義一致性。這種自我進化的方式,不僅提高了模型的資料效率,還使其在多種臨床任務中展現出了卓越的推理能力。
圖 1 MedS3 框架的構建過程。
圖 2 PRM 引導的投票求和計算示例。
MedS3 的優勢在於:
- 資料利用率高:使用自我啟發式搜尋,擴大了資料的表徵範圍和利用率。
- 支援單步監督:搜尋中的進化展開值可以為單步推理提供監督,從而規劃正確推理軌跡。
- 高效支援多工學習:對每個資料集取樣約 500 條資料即可實現多工同時學習。
圖 3 種子資料集中各任務及設計的資料集。
實驗結論
MedS3 的任務同樣涵蓋了來自不同任務的 11 個資料集,涵蓋了知識問答、生物醫學問答、長上下文問答、醫療語義推理以及醫療診斷式問答。
實驗 1:同時領先醫療開源模型、通用推理模型
醫療模型的效能提升普遍較小,最大提升僅為 6.48(MMedS-Ins vs Llama3 8B),且多數模型因任務覆蓋不全,難以超越通用的 Llama3 8B。
通用域強化推理能力的開源模型因缺乏醫療知識,效果受限,最佳表現 R1-Distill-Qwen32B 僅提升 4.36。
相比之下,MedS3 透過融合推理能力和醫療知識,並採用創新的 PRM 引導投票求和方法,相比 Llama3 8B 提升了 13.07,顯著超越所有同等規模的開源模型,並在綜合效能上領先更大規模的模型。
實驗 2:P-VS 選擇策略領先以往的 PRM 使用方法
傳統 Best-of-N(BoN)方法依賴 PRM 篩選最優回覆,但 PRM 存在訓練不穩定性和訓練/真實標籤偏差問題。P-VS 創新性融合語義一致性校驗與 PRM 評分,突破單一依賴瓶頸,實現 3.46 的效能躍升。
實驗 3:幾乎無界的效能外推
推理模型透過增加詞元消耗,幾乎無損提升效能,MedS3 也具備此特性。
透過在 5 個診斷資料集上分別設定取樣條數為 2、4、8、12、16,採用 P-VS 策略評估其效能,結果顯示,除了 Pubhealth 資料集效能收斂外,其餘資料集呈現出幾乎無界的效能提升,驗證了 MedS3 強大的效能外推潛力。
圖 4 MedS3 詞元消耗量和效能的關係。
實驗 4:MCTS+PRM 仍是實現醫療推理的最有效方案
DeepSeek R1 證實強化學習(RL)可提升模型推理能力,但研究者們的實驗表示,醫療領域中,基於蒸餾的小模型泛化性不足,傳統 RL 方法效能仍弱於 MCTS+PRM 正規化(如 MedS3 所示)。
醫療場景的特殊性使 MCTS+PRM 優勢顯著:醫學指南的強結構化特徵天然適配過程監督需求——策略模型可精準劃分診療步驟,PRM 能有效完成分步評估,規避其常規場景下的步驟劃分難題。
值得注意的是,MCTS+PRM 與 RL 具備互補性,聯合應用可進一步提升模型泛化能力。
結論和展望
這篇工作釋出了涵蓋多個醫療任務的推理系統 MedS3,透過蒙特卡洛樹搜尋訓練了一個策略模型和一個過程監督模型,是醫療推理模型進展上的一個重要工作。
隨著 DeepSeek R1 的釋出,強化學習以其高泛化性和高資料利用率成為通用域廣泛使用的方案。如何將強化學習的思想融入到醫療推理,使其能有效和過程監督結合,仍然是值得思考的一個問題。