WHALE來了,南大周志華團隊做出更強泛化的世界模型

机器之心發表於2024-11-13

世界模型又出新進展了,來自國內機構。

人類能夠在腦海中設想一個想象中的世界,以預測不同的動作可能導致不同的結果。受人類智慧這一方面的啟發,世界模型被設計用於抽象化現實世界的動態,並提供這種「如果…… 會怎樣」的預測。

因此,具身智慧體可以與世界模型進行互動,而不是直接與現實世界環境互動,以生成模擬資料,這些資料可以用於各種下游任務,包括反事實預測、離線策略評估、離線強化學習。

世界模型在具身環境的決策中起著至關重要的作用,使得在現實世界中成本高昂的探索成為可能。為了促進有效的決策,世界模型必須具備強大的泛化能力,以支援分佈外 (OOD) 區域的想象,並提供可靠的不確定性估計來評估模擬體驗的可信度,這兩者都對之前的可擴充套件方法提出了重大挑戰。

本文,來自南京大學、南棲仙策等機構的研究者引入了 WHALE(World models with beHavior-conditioning and retrAcing-rollout LEarning),這是一個用於學習可泛化世界模型的框架,由兩種可以與任何神經網路架構普遍結合的關鍵技術組成。

圖片

  • 論文地址:https://arxiv.org/pdf/2411.05619

  • 論文標題:WHALE: TOWARDS GENERALIZABLE AND SCALABLE WORLD MODELS FOR EMBODIED DECISION-MAKING

首先,在確定策略分佈差異是泛化誤差的主要來源的基礎上,作者引入了一種行為 - 條件(behavior-conditioning)技術來增強世界模型的泛化能力,該技術建立在策略條件模型學習的概念之上,旨在使模型能夠主動適應不同的行為,以減輕分佈偏移引起的外推誤差。

此外,作者還提出了一種簡單而有效的技術,稱為 retracing-rollout,以便對模型想象進行有效的不確定性估計。作為一種即插即用的解決方案, retracing-rollout 可以有效地應用於各種實施任務中的末端執行器姿態控制,而無需對訓練過程進行任何更改。

為了實現 WHALE 框架,作者提出了 Whale-ST,這是一個基於時空 transformer 的可擴充套件具身世界模型,旨在為現實世界的視覺控制任務提供忠實的長遠想象。

為了證實 Whale-ST 的有效性,作者在模擬的 Meta-World 基準和物理機器人平臺上進行了廣泛的實驗。

在模擬任務上的實驗結果表明,Whale-ST 在價值估計準確率和影片生成保真度方面均優於現有的世界模型學習方法。此外,作者還證明了基於 retracing-rollout 技術的 Whale-ST 可以有效捕獲模型預測誤差並使用想象的經驗增強離線策略最佳化。

作為進一步的舉措,作者引入了 Whale-X,這是一個具有 414M 引數的世界模型,該模型在 Open X-Embodiment 資料集中的 970k 個現實世界演示上進行了訓練。透過在完全沒見過的環境和機器人中的一些演示進行微調,Whale-X 在視覺、動作和任務視角中展示了強大的 OOD 通用性。此外,透過擴大預訓練資料集或模型引數,Whale-X 在預訓練和微調階段都表現出了令人印象深刻的可擴充套件性。

圖片

總結來說,這項工作的主要貢獻概述如下:

  • 作者引入了 WHALE,這是一個學習可泛化世界模型的框架,由兩項關鍵技術組成:行為 - 條件(behavior-conditioning)和 retracing-rollout,以解決世界模型在決策應用中的兩個主要挑戰:泛化和不確定性估計;

  • 透過整合 WHALE 的這兩種技術,作者提出了 Whale-ST,這是一種可擴充套件的基於時空 transformer 的世界模型,旨在實現更有效的決策,作者進一步提出了 Whale-X,這是一個在 970K 機器人演示上預訓練的 414M 引數世界模型;

  • 最後,作者進行了大量的實驗,以證明 Whale-ST 和 Whale-X 在模擬和現實世界任務中的卓越可擴充套件性和泛化性,突出了它們在增強決策方面的效果。

學習可泛化的世界模型以進行具身決策

世界模型中的序列決策通常需要智慧體探索超出訓練資料集的分佈外 (OOD) 區域。這要求世界模型表現出強大的泛化能力,使其能夠做出與現實世界動態密切相關的準確預測。同時,可靠地量化預測不確定性對於穩健的決策至關重要,這可以防止離線策略最佳化利用錯誤的模型預測。考慮到這些問題,作者提出了 WHALE,這是一個用於學習可泛化世界模型的框架,具有增強的泛化性和高效的不確定性估計。

用於泛化的行為 - 條件

根據公式(2)的誤差分解可知,世界模型的泛化誤差主要來源於策略分歧引起的誤差積累。

圖片

為了解決這個問題,一種可能的解決方案是將行為資訊嵌入到世界模型中,使得模型能夠主動識別策略的行為模式,並適應由策略引起的分佈偏移。

基於行為 - 條件,作者引入了一個學習目標,即從訓練軌跡中獲取行為嵌入,並整合學習到的嵌入。

作者希望將訓練軌跡 τ_H 中的決策模式提取到行為嵌入中,這讓人聯想到以歷史 τ_h 為條件的軌跡似然 ELBO(evidence lower bound)的最大化:

圖片

作者建議透過最大化 H 個決策步驟上的 ELBO 並調整類似於 β-VAE 的 KL 約束數量來學習行為嵌入:

圖片

這裡,KL 項將子軌跡的嵌入預測約束到每個時間步驟 h,鼓勵它們近似後驗編碼。這確保了表示保持策略一致,這意味著由相同策略生成的軌跡表現出相似的行為模式,從而表現出相似的表示。然後使用學習到的先驗預測器圖片從歷史 τ_h 中獲得行為嵌入 z_h,以便在世界模型學習期間進行行為調節,其中行為嵌入被接受為未來預測的額外協變數:

圖片

不確定性估計 Retracing-rollout

世界模型不可避免地會產生不準確和不可靠的樣本,先前的研究從理論和實驗上都證明,如果無限制地使用模型生成的資料,策略的效能可能會受到嚴重損害。因此,不確定性估計對於世界模型至關重要。

作者引入了一種新穎的不確定性估計方法,即 retracing-rollout。retracing-rollout 的核心創新在於引入了 retracing-action,它利用了具身控制中動作空間的語義結構,從而能夠更準確、更高效地估計基於 Transformer 的世界模型的不確定性。

圖片

接下來作者首先介紹了 retracing-action,具體地說,retracing-action 可以等效替代任何給定的動作序列,形式如公式(5),其中圖片表示動作 a_i 第 j 維的值。

圖片

接下來是一個全新的概念:Retracing-rollout。

具體來說:假設給定一個「回溯步驟」k,整個過程開始於從當前時間步 t,回溯到時間步 t-k,將 o_t−k 作為起始幀。

然後,執行一個回溯動作圖片,從 o_t−k 開始,生成相應的結果 o_k+1。

在實際操作中,為了避免圖片超出動作空間的範圍,回溯動作被分解為 k 步。在每一步中,前六個維度的動作被設定為圖片,而最後一個維度圖片保持不變。透過這種方式,模型可以透過多步回溯產生期望的結果。

為了估計某一時間點 (o_t,a_t) 的不確定性,採用多種回溯步驟生成不同的回溯 - 軌跡預測結果。具體來說,要計算不同回溯 - 軌跡輸出與不使用回溯的輸出之間的「感知損失」。同時,引入動態模型的預測熵,透過將「感知損失」和預測熵相乘,得到最終的不確定性估計結果。

與基於整合的其他方法不同,retracing-rollout 方法不需要在訓練階段進行任何修改,因此相比整合方法,它顯著減少了計算成本。

作者在論文中還給出了具體的例項。圖 3 展示了 Whale-ST 的整體架構。具體來說,Whale-ST 包含三個主要元件:行為調節模型、影片 tokenizer 和動態模型。這些模組採用了時空 transformer 架構。

這些設計顯著簡化了計算需求,從相對於序列長度的二次依賴關係簡化為線性依賴關係,從而降低了模型訓練的記憶體使用量和計算成本,同時提高了模型推理速度。

圖片

實驗

該團隊在模擬任務和現實世界任務上進行了廣泛的實驗,主要是為了回答以下問題:

  • Whale-ST 在模擬任務上與其他基線相比表現如何?行為 - 條件和 retracing-rollout 策略有效嗎?

  • Whale-X 在現實世界任務上的表現如何?Whale-X 能否從網際網路規模資料的預訓練中受益?

  • Whale-X 的可擴充套件性如何?增加模型引數或預訓練資料是否能提高在現實世界任務上的表現?

模擬任務中的 Whale-ST

該團隊在 Meta-World 基準測試上開展實驗。Meta-World 是一個包含多種視覺操作任務的測試集。研究者們構建了一個包含 6 萬條軌跡的訓練資料集,這些軌跡是從 20 個不同的任務中收集來的。模型學習演算法需要使用這些資料從頭開始訓練。

研究團隊將 Whale-ST 與 FitVid、MCVD、DreamerV3、iVideoGPT 進行了對比。評估指標如下:
  • 預測準確性:驗證模型是否能夠正確估計給定動作序列的值,具體透過值差、回報相關性 (Return Correlation) 和 Regret 進行評估;

  • 影片保真度:研究團隊採用 FVD、PSNR、LPIPS 和 SSIM 來衡量影片軌跡生成的質量。

下表展示了預測準確性的結果,其中,Whale-ST 在所有三個指標上都表現出色。在 64 × 64 的解析度下,Whale-ST 的值差與 DreamerV3 的最高分非常接近。當在更高解析度 256 × 256 測試時,Whale-ST 的表現進一步提升,取得了最小的值差和最高的回報相關性,反映了 Whale-ST 能更細緻地理解動態環境。
圖片
表 2 展示了影片保真度的結果,Whale-ST 在所有指標上均優於其他方法,特別是 FVD 具有顯著優勢。
圖片
不確定性估計

針對不確定性,研究團隊比較了 retracing-rollout 與兩種基準方法:

(1)基於熵的方法:研究團隊採用基於 Transformer 的動態模型,它透過計算模型輸出的預測熵來量化不確定性
(2)基於整合的方法:研究團隊訓練了三個獨立的動態模型,然後透過比較每個模型生成的影像之間的畫素級差異來估計不確定性。

具體來說,他們從模型誤差預測和離線強化學習兩個角度進行評估。

下表展示了模型誤差預測的結果,在所有 5 個任務中,retracing-rollout 均優於其他基線方法。與基於整合的方法相比,retracing-rollout 提升了 500%,與基於熵的方法相比,提高了 50%。
圖片
下圖展示了離線 MBRL 的結果,retracing-rollout 在 5 個任務中的 3 個任務中收斂得更好、具備更強的穩定性。特別是在關水龍頭和滑盤子任務中,retracing-rollout 是唯一能夠穩定收斂的方法,而其他方法在訓練後期出現了不同程度的效能下降。
圖片
Whale-X 在真實世界中的表現

為了評估 Whale-X 在實際物理環境中的泛化能力,研究團隊在 ARX5 機器人上進行了全面實驗。

與預訓練資料不同,評估任務調整了攝像機角度和背景等,增加了對世界模型的挑戰。他們收集了每個任務 60 條軌跡的資料集用於微調,任務包括開箱、推盤、投球和移動瓶子,還設計了多個模型從未接觸過的任務來測試模型的視覺、運動和任務泛化能力。

如圖 5 所示,Whale-X 在真實世界中展現出了明顯的優勢。

具體來說:
圖片
1. 與沒有行為 - 條件的模型相比,Whale-X 的一致性提高了 63%,表明該機制顯著提升了 OOD 泛化能力;
2. 在 97 萬個樣本上進行預訓練的 Whale-X,比從零開始訓練的模型具有更高的一致性,凸顯了大規模網際網路資料預訓練的優勢;
3. 增加模型引數能夠提升世界模型的泛化能力。Whale-X-base(203M)動態模型在三個未見任務中的一致性率是 77M 版本的三倍。

此外,影片生成質量與一致性的結果一致,如表 4 所示。透過行為 - 條件策略、大規模預訓練資料集和擴充套件模型引數,三種策略結合,顯著提高了模型的 OOD 泛化能力,尤其是在生成高質量影片方面。
圖片
擴充套件性

固定影片 token 和行為 - 條件這兩個部分不變,僅調整模型的引數量和預訓練資料集的大小,Whale-X 的擴充性如何呢?

研究團隊在預訓練階段訓練了四個動態模型,引數數量從 39M 到 456M 不等,結果如圖 7 的前兩幅圖所示。
圖片
這些結果表明,Whale-X 展現出強大的擴充套件性:無論是增加預訓練資料還是增加模型引數,都會降低訓練 loss。

除此之外,研究團隊還驗證了更大的模型在微調階段是否能夠展現更好的效能。

為此,他們微調了一系列動態模型,結果如圖 7 最左側所示。不難發現,經過微調後,更大的模型在測試資料上表現出更低的 loss,進一步突顯了 Whale-X 在真實任務中出色的擴充套件性。

視覺化
  • 定性評估

圖 1 展示了在 Meta-World、Open X-Embodiment 和研究團隊設計的真實任務上的定性評估結果。
圖片
結果表明,Whale-ST 和 Whale-X 能夠生成高保真度的影片軌跡,尤其是在長時間跨度的軌跡生成過程中,保持了影片的質量和一致性。
  • 可控生成

圖 8 展示了 Whale-X 在控制性和泛化性方面的強大能力。給定一個未見過的動作序列,Whale-X 能夠生成與人類理解相符的影片,學習動作與機器人手臂移動之間的因果聯絡。
圖片
  • 行為條件視覺化

透過 t-SNE 視覺化,研究表明 Whale-X 成功地學習到行為嵌入,能夠區分不同策略之間的差異。例如,對於同一任務,不同的策略會有不同的行為表示,而噪聲策略的嵌入則介於專家策略和隨機策略之間,體現了模型在策略建模上的合理性。此外,專家策略在不同任務中的嵌入也能被區分,而隨機策略則無法區分,表明模型更擅長表示和區分策略,而不是任務本身。
圖片
更多研究細節,請參考原文。

參考連結:https://arxiv.org/abs/2411.05619

相關文章