交大O1醫療探索:延長AI思考時間,解鎖複雜推理診斷

ScienceAI發表於2025-01-15

圖片

編輯 | ScienceAI

當醫生面對複雜病例時,往往需要反覆思考、權衡多種可能性,才能得出準確診斷。以鑑別診斷為例,它要求醫生生成可能的診斷列表,並透過評估臨床發現,逐步排除不符合條件的選項。

如今,AI 也學會了這種「深思熟慮」的診斷方式。上海交通大學最新研究發現,給 AI 更多「思考時間」,能顯著提升其醫療診斷能力,讓 AI 更接近專業醫生的診斷水平。

上海交通大學近日釋出了 O1 復現專案系列研究的第三部分成果。

這項由 SPIRAL 實驗室與生成式 AI 研究實驗室(GAIR)聯合完成的研究表明,透過延長AI的推理時間,僅需 500 個樣本訓練,就能讓模型在醫療診斷準確率上提升 6%-11%。

在實際測試中,改進後的 AI 系統能夠像專業醫生一樣,系統性地分析症狀、評估證據,逐步縮小診斷範圍,最終得出合理結論。

「這就像是讓 AI 學會了醫生看診時的思維方式。」專案負責人表示,「在面對複雜病例時,AI 不再僅僅依靠快速匹配,而是能夠進行更深入的分析和推理。這種方法在 JAMA 臨床挑戰等真實醫療場景測試中取得了令人振奮的效果。」

研究還揭示了一個有趣的發現:越是複雜的醫療問題,AI 就需要更長的推理鏈來得出準確結論。這與人類醫生的診斷過程驚人地相似,為提升 AI 在臨床實踐中的應用提供了全新思路。

該研究是繼 Journey Learning 和知識蒸餾研究之後的最新突破,進一步推進了 O1 在專業領域的應用探索。為促進醫療 AI 的開放發展,研究團隊已將所有程式碼和資料集在 GitHub 上開源。

圖片

  • 技術文件:http://arxiv.org/abs/2501.06458

  • 相關資源將近日公開:https://github.com/GAIR-NLP/O1-Journey , https://github.com/SPIRAL-MED/Ophiuchus

探索過程

透過對現有案例的分析,可以發現:隨著問題難度的增加,推理時間(inference-time)往往會按比例增加。這表明更高難度的問題需要更多的推理步驟,這反過來也需要更長的推理時間。

推理時間的擴充套件在識別和分析關鍵資訊方面貢獻顯著,這一現象在醫學領域尤為重要,因為臨床醫生需要花費大量時間處理來自多種來源和模態的資料,以診斷病情、進行預後評估和確定治療方案。

為了證明推理時擴充套件(inference-time scaling)在解決醫學問題中的有效性, 團隊選擇了在先前工作提出的三個基準資料集:JAMA 臨床挑戰(JAMA)、Medbullets 和 MedQA。這些基準測試包含來自多個醫學領域的複雜真實臨床案例以及不同難度級別的醫學執業考試題目。

  • JAMA 資料集:包含從 2013 年 7 月至 2023 年 10 月 JAMA Network Clinical Challenge 檔案中收集的 1,524 個案例,涵蓋 13 個醫學領域。這些案例涉及複雜的臨床場景,包括患者病史、家族病史、實驗室結果、物理檢查、影像分析等,因此需要更復雜的理解和推理才能得出正確的診斷。為評估推理時擴充套件在複雜任務中的有效性,團隊選擇了 o1-mini 模型難以應對的 646 個案例進行評估。

  • Medbullets 和 MedQA 資料集:基於美國國家醫學委員會考試(USMLE)的題目。


    • Medbullets:是一個線上醫學學習平臺,包含 Step 2 和 Step 3 級別的題目,這些題目更強調臨床知識和推理,而不是依賴於課本知識。

    • MedQA:包含部分來自 Medbullets 網站的題目,但不包括詳細解釋。

在當前階段,團隊的主要目標是評估推理時擴充套件(inference-time scaling)在解決醫學問題中的作用。在資訊和資源有限的情況下,團隊沒有選擇直接嘗試直接執行鑑別診斷(differential diagnosis)這一極其困難的任務。

在現實場景中,鑑別診斷符合假設演繹法(hypothetico-deductive method)的原則,即將潛在的疾病或病症視為假設,供臨床醫生評估其有效性。

為了簡化任務,當前部分採用了多項選擇資料集,透過預定義的潛在診斷(即「鑑別」)來指導模型生成假設。團隊沒有選擇直接使用私有資料,因為現實中的臨床場景通常包含大量無關資訊,這些資訊可能干擾推理過程,對當前模型構成巨大挑戰。相比之下,公共基準測試簡化了問題,並消除了部分干擾。同時,分析選擇題選項以確定最終答案的過程與臨床診斷中的思維過程高度相似。

研究團隊在先前的 O1-Journey (Part1 和 Part2) 中驗證了長思維鏈資料對於複雜推理的重要性,並且在構造長思維鏈資料(journey learning)上面取得了一定的成果。

為了使大語言模型在解決醫學問題時能夠進行「深度」的思考,團隊在 Part1 和 Part2 的基礎上,構造用於解決醫療領域中複雜推理問題的長思維鏈資料。參照 Part2 的方式,團隊採用知識蒸餾的方式,使用了 o1 模型生成的高質量資料。生成兩種型別的長思維鏈資料可以分為兩類:

  1. LongStep:提取 o1 模型的解決步驟,訓練 LLMs 模仿這一行為,生成更詳細的解決方案。

  2. LongMonolog:設計提示使 o1-preview 模型將其總結的思路擴充套件為長形式推理,以模擬「內心獨白」風格的詳細解決過程。

為了進一步最佳化資料,團隊對合成的資料進行了過濾篩選以確保質量,同時規範化了格式輸出。在選擇訓練資料樣本時,團隊著重關注問題解決過程的長度,排除了推理過程較短的案例。最終構建了一個包含 500 個樣本的訓練資料集,其中 350 個樣本來自 MedQA 的訓練集,150 個樣本來自 JAMA。

實驗結果

考慮到解決醫學問題需要模型在醫學領域具備良好的基礎能力,團隊選用了 Qwen2.5-32B-Instruct、Qwen2.5-72B-Instruct 以及 LLama3.1-70B-Instruct 作為開展實驗的基礎模型。

圖片

圖示:各個模型在三個醫學資料集上的效能比較。Mean Acc. 表示三個資料集的加權平均準確率。標註為橙色的行(帶有「CoT SFT」字尾)展示了使用 GPT-4o 的原始 CoT 進行訓練的模型結果。標註為藍色的行對應於在 Journey Learning 資料(帶有「-LongStep」和「-LongMonolog」字尾)上微調的模型結果。

團隊展示了各種方法在評估基準測試上的綜合效能比較,包括專有 API、開源基線模型,以及採用構造的 Journey Learning 資料進行微調的多種模型。

為了反映推理時擴充套件(inference-time scaling)的有效性,團隊同時比較了各個模型平均輸出 Token 的數量。

結果表明,更多推理時間帶來更好的效能。例如,當 Qwen2.5-72B 透過逐步推理(無論是 Vanilla CoT 還是 CoT SFT)進行推理時,輸出的 token 長度範圍在 300 到 500 之間,導致平均準確率增加約 5%。相比之下,在利用 Journey Learning 資料進行微調的(如 LongStep 和 LongMonolog),輸出 token 長度延長至約 1000,效能改進約為 10%,這一趨勢同樣體現在 Qwen2.5-32B 和 LLama3.1-70B。

圖片

圖示:對 Qwen2.5-72B-Instruct、LLama3.1-70B 和 Qwen2.5-32B 使用不同策略的在三個資料集上效能比較。

為了直觀地說明推理時間計算的貢獻,團隊展示了 Qwen2.5-72B、LLama3.1-70B Qwen2.5-32B 在三種基準資料集上的準確率,使用了不同策略,Vanilla、Vanilla CoT、CoT SFT、LongStep SFT 以及 LongMonolog SFT。每種策略都顯著提高了總體準確率。特別是,對於 Qwen2.5-72B,不同策略帶來了以下改進:

  • Vanilla CoT: +3.28%

  • CoT SFT: +5.12%

  • LongStep SFT: +9.69%

  • LongMonolog SFT: +11.36%

發現 1: 多數表決法(Majority Voting)的作用

圖片

圖示:推理時擴充套件在 Qwen2.5-72B 相關模型上的作用。

多數表決法是一種常見的推理時擴充套件(inference-time scaling)策略,透過多次計算的結果進行投票匯總來提高推理質量。團隊在 MedQA 資料集上測試了 Qwen2.5-72B 模型。雖然 Vanilla Qwen2.5-72B 透過多數表決法顯示了穩步效能提升,但提高幅度有限(準確率從 74.31% 增加到 74.63%)。

相比之下,當多數表決法與 CoT 推理 (Vanilla CoT)結合使用時,改進更為顯著。然而,準確率達到頂峰(80.44%),隨後略微下降(79.81%)。Journey Learning 策略(LongStep 和 LongMonolog)也觀察到了類似趨勢,但改進更加明顯。例如:

  • LongStep 透過多數表決法提高了 1.26%;

  • LongMonolog 提高了 1.50%。

結論:儘管多數表決法可以透過聚合多次執行的輸出來最佳化預測,但對於缺乏思考深度的中間步驟,其效果有限。而 Journey Learning 透過細緻的推理過程,更有潛力利用多數表決法來增強效能。

發現 2: LongStep 與 LongMonolog 效能的比較

團隊在比較 LongStep 和 LongMonolog 時,很難確定哪種方式始終具有更高的效能。從當前實驗資料來看,LongMonolog 在 MedbulletsMedQA 資料集上表現出更高的準確率,但在 JAMA 資料集中未能保持優勢。例如,在 JAMA 資料集中,Qwen2.5-32B 在 LongStep模式下的準確率為 56.34%,但在 LongMonolog 模式下僅為 53.71%。

透過觀察輸出示例,團隊發現 Qwen2.5-32B 可能在構建完整推理鏈時存在不足,導致效能下降。過長的推理步驟可以帶來正確答案,而冗餘反思有時會導致錯誤。這表明,儘管推理時間內延長思路鏈條可以幫助回答複雜醫學問題,但前提是模型具備足夠的領域知識。

圖片

圖片

圖片

發現 3: 更難的任務、更長的思考、更長的推理時間

圖片

圖示:對 Qwen2.5-72B 在三個資料集上使用不同策略(從左到右分別為 Vanilla CoT、LongStep 和 LongMonolog)的準確率和平均輸出 token 數量進行比較。

團隊發現對於更難的任務,需要更多的輸出 token 才能從推理時間計算中獲益。為了解釋任務難度的層級,假設回答 JAMA 中的問題比 Medbullets 和 MedQA 中的問題更具挑戰性,因為 JAMA 呈現了更復雜的真實世界場景,即使是專有模型在 JAMA 上的表現也不理想。此外,Medbullets 的平均難度要高於比 MedQA,因為 MedQA 部分包括了 USMLE 的 Step 1 題目。透過進一步分析輸出長度,Qwen2.5-72B 在回答 JAMA 問題時的平均輸出 token 數量為 1,076,而在 Medbullets 中為 917,在 MedQA 中為 873。

發現4: 推理時擴充套件和模型大小的關係

圖片

圖示:引導不同模型逐步解決問題的收益。

團隊發現對於較小引數規模的模型(例如 7B 或 20B),推理時間的增加反而可能導致效能下降,甚至有時無法遵循指令的輸出格式。在難度更高的資料集(如 JAMA)上,這種現象尤為明顯。JAMA 包含複雜的真實臨床案例,要求廣泛的領域知識進行分析,效能缺陷尤為顯著。

另一個值得注意的觀察是,引數較少的模型(如 Qwen2.5-32B)從推理時擴充套件(inference-time scaling)中獲得的收益小於較大容量的模型。

基於這些發現,團隊提出了以下假設:推理時間中長時間思維的有效性依賴於足夠的能力。這在醫學領域尤為重要,因為解決臨床問題需要理解和生成複雜且細緻的文字能力,以及廣泛的知識儲備,包括疾病、藥理學和治療方案等方面的知識。

泛化能力與未來方向

透過仔細分析構造的資料,團隊發現這些資料並不侷限於以提供的選項作為輸出的參考。在推理過程中,模型將這些選項內化為啟發式方法,生成更接近完整診斷的輸出,包括差異候選項的生成及排除,而不是逐一討論選項。

為了驗證使用 Journey Learning 資料訓練的模型在鑑別診斷中的有效性,團隊進行了一項初步研究:團隊移除了多項選擇題的選項,並讓模型自由地進行回答。

為確保公平,團隊選擇了 2024 年 JAMA Clinical Challenges 中發表的案例,而訓練資料則收集於 2023 年 10 月之前。儘管訓練資料包括了多項選擇題選項的提供,但實驗結果表明,使用長形式推理的模型更傾向於分析更廣泛的潛在疾病,並整合多種背景資訊和知識,從而得出更為精確的結論。這些發現為未來的研究方向提供了有價值的啟示。

圖片

圖片

圖片

總結

透過團隊對推理時擴充套件(inference-time scaling)在醫學領域應用的初步探索,研究團隊發現這一方法在處理複雜推理任務時表現出巨大的潛力。

本研究展示了推理時擴充套件(inference-time scaling)顯著提升了模型在諸如 MedQA、Medbullets 和 JAMA 臨床挑戰等基準測試中的表現。在僅用 500 個訓練樣本的情況下,模型準確率提升達 6% 至 11%。

研究團隊希望:透過持續探索和迭代改進,提高推理時擴充套件在解決實際醫學問題中的可解釋性和有效性;透過專注於協作研究和開放資源共享,加強計算機技術與實際醫學應用之間的聯絡,最終改善診斷準確性、患者治療結果和醫療效率。

相關文章