在影像生成領域佔據主導地位的擴散模型,開始挑戰強化學習智慧體。
世界模型提供了一種以安全且樣本高效的方式訓練強化學習智慧體的方法。近期,世界模型主要對離散潛在變數序列進行操作來模擬環境動態。
然而,這種壓縮為緊湊離散表徵的方式可能會忽略對強化學習很重要的視覺細節。另一方面,擴散模型已成為影像生成的主要方法,對離散潛在模型提出了挑戰。
受這種正規化轉變的推動,來自日內瓦大學、愛丁堡大學、微軟研究院的研究者聯合提出一種在擴散世界模型中訓練的強化學習智慧體 —— DIAMOND(DIffusion As a Model Of eNvironment Dreams)。
論文地址:https://arxiv.org/abs/2405.12399
專案地址:https://github.com/eloialonso/diamond
論文標題:Diffusion for World Modeling: Visual Details Matter in Atari
DIAMOND 在 Atari 100k 基準測試中獲得了 1.46 的平均人類歸一化得分 (HNS),可以媲美完全在世界模型中訓練的智慧體的 SOTA 水平。該研究提供了定性分析來說明,DIAMOND 的設計選擇對於確保擴散世界模型的長期高效穩定是必要的。
此外,在影像空間中操作的好處是使擴散世界模型能夠成為環境的直接替代品,從而提供對世界模型和智慧體行為更深入的瞭解。特別地,該研究發現某些遊戲中效能的提高源於對關鍵視覺細節的更好建模。
方法介紹
接下來,本文介紹了 DIAMOND, 這是一種在擴散世界模型中訓練的強化學習智慧體。具體來說,研究者基於 2.2 節引入的漂移和擴散係數 f 和 g,這兩個係數對應於一種特定的擴散正規化選擇。此外,該研究還選擇了基於 Karras 等人提出的 EDM 公式。
首先定義一個擾動核,,其中, 是一個與擴散時間相關的實值函式,稱為噪聲時間表。這對應於將漂移和擴散係數設為 和。
接著使用 Karras 等人(2022)引入的網路預處理,同時引數化公式(5)中的,作為噪聲觀測值和神經網路 預測值的加權和:
得到公式(6)
其中為了簡潔定義,包含所有條件變數。
前處理器的選擇。選擇前處理器和,以保持網路輸入和輸出在任何噪聲水平 下的單位方差。 是噪聲水平的經驗轉換, 由 和資料分佈的標準差 給出,公式為
結合公式 5 和 6,得到訓練目標:
該研究使用標準的 U-Net 2D 來構建向量場,並保留一個包含過去 L 個觀測和動作的緩衝區,以此來對模型進行條件化。接下來他們將這些過去的觀測按通道方式與下一個帶噪觀測拼接,並透過自適應組歸一化層將動作輸入到 U-Net 的殘差塊中。正如在第 2.3 節和附錄 A 中討論的,有許多可能的取樣方法可以從訓練好的擴散模型中生成下一個觀測。雖然該研究釋出的程式碼庫支援多種取樣方案,但該研究發現尤拉方法在不需要額外的 NFE(函式評估次數)以及避免了高階取樣器或隨機取樣的不必要複雜性的情況下是有效的。
實驗
為了全面評估 DIAMOND,該研究使用了公認的 Atari 100k 基準測試,該基準測試包括 26 個遊戲,用於測試智慧體的廣泛能力。對於每個遊戲,智慧體只允許在環境中進行 100k 次操作,這大約相當於人類 2 小時的遊戲時間,以在評估前學習翫遊戲。作為參考,沒有限制的 Atari 智慧體通常訓練 5000 萬步,這相當於經驗的 500 倍增加。研究者從頭開始在每個遊戲上用 5 個隨機種子訓練 DIAMOND。每次執行大約使用 12GB 的 VRAM,在單個 Nvidia RTX 4090 上大約需要 2.9 天(總計 1.03 個 GPU 年)。
表 1 比較了在世界模型中訓練智慧體的不同得分:
圖 2 中提供了平均值和 IQM( Interquartile Mean )置信區間:
結果表明,DIAMOND 在基準測試中表現強勁,超過人類玩家在 11 個遊戲中的表現,並達到了 1.46 的 HNS 得分,這是完全在世界模型中訓練的智慧體的新紀錄。該研究還發現,DIAMOND 在需要捕捉細節的環境中表現特別出色,例如 Asterix、Breakout 和 Road Runner。
為了研究擴散變數的穩定性,該研究分析了自迴歸生成的想象軌跡(imagined trajectory),如下圖 3 所示:
該研究發現有些情況需要迭代求解器將取樣過程驅動到特定模式,如圖 4 所示的拳擊遊戲:
如圖 5 所示,與 IRIS 想象的軌跡相比,DIAMOND 想象的軌跡通常具有更高的視覺質量,並且更符合真實環境。
感興趣的讀者可以閱讀論文原文,瞭解更多研究內容。