無限生成影片,還能規劃決策,擴散強制整合下一token預測與全序列擴散

机器之心發表於2024-07-23

當前,採用下一 token 預測正規化的自迴歸大型語言模型已經風靡全球,同時網際網路上的大量合成影像和影片也早已讓我們見識到了擴散模型的強大之處。

近日,MIT CSAIL 的一個研究團隊(一作為 MIT 在讀博士陳博遠)成功地將全序列擴散模型與下一 token 模型的強大能力統合到了一起,提出了一種訓練和取樣正規化:Diffusion Forcing(DF)。

圖片
  • 論文標題:Diffusion Forcing:Next-token Prediction Meets Full-Sequence Diffusion

  • 論文地址:https://arxiv.org/pdf/2407.01392

  • 專案網站:https://boyuan.space/diffusion-forcing

  • 程式碼地址:https://github.com/buoyancy99/diffusion-forcing

如下所示,擴散強制在一致性和穩定性方面都明顯勝過全序列擴散和教師強制這兩種方法。

圖片

在該框架中,每個 token 都關聯了一個隨機的、獨立的噪聲水平,並且可使用一種共享的下一 token 預測模型或下幾 token 預測模型根據任意的、獨立的、每 token 的方案對 token 進行去噪。

該方法的研究靈感來自這一觀察:對 token 加噪聲的過程就是一種形式的部分掩碼過程 —— 零噪聲就意味著未對 token 加掩碼,而完整噪聲則是完全掩蔽 token。因此,DF 可強迫模型學習去除任何可變有噪聲 token 集合的掩碼(圖 2)。
圖片
與此同時,透過將預測方法引數化為多個下一 token 預測模型的組合,該系統可以靈活地生成不同長度的序列,並以組合方式泛化到新的軌跡(圖 1)。
圖片
該團隊將用於序列生成的 DF 實現成了因果擴散強制(Causal Diffusion Forcing/CDF),其中未來 token 透過一個因果架構依賴於過去 token。他們訓練該模型一次性去噪序列的所有 token(其中每個 token 都有獨立的噪聲水平)。

在取樣期間,CDF 會將一個高斯噪聲幀序列逐漸地去噪成潔淨的樣本,其中不同幀在每個去噪步驟可能會有不同的噪聲水平。類似於下一 token 預測模型,CDF 可以生成長度可變的序列;不同於下一 token 預測,CDF 的表現非常穩定 —— 不管是預測接下來的一個 token,還是未來的數千 token,甚至是連續 token。

此外,類似於全序列擴散,它也可接收引導,從而實現高獎勵生成。透過協同利用因果關係、靈活的範圍和可變噪聲排程,CDF 能實現一項新功能:蒙特卡洛樹引導(MCTG)。相比於非因果全序列擴散模型,MCTG 能極大提升高獎勵生成的取樣率。圖 1 給出了這些能力的概況。

Diffusion Forcing(擴散強制)

1、將加噪過程視為部分掩碼

首先,我們可以將任意 token 集合(不管是否為序列)視為一個透過 t 索引的有序集合。那麼,使用教師強制(teacher forcing)訓練下一 token 預測便可被解釋成掩蔽掉時間 t 的每個 token x_t 基於過去 x_{1:t−1} 預測它們。

對於序列,可將這種操作描述成:沿時間軸執行掩碼。我們可以將全序列前向擴散(即逐漸向資料圖片新增噪聲的過程)看作一種部分掩碼(partial masking),這可被稱為「沿噪聲軸執行掩碼)。

事實上,在 K 步加噪之後,圖片(大概)就是白噪聲了,不再有任何有關原資料的資訊。

如圖 2 所示,該團隊建立了一個統一視角來看待沿這兩個軸的掩碼。

2、擴散強制:不同 token 的噪聲水平不同

擴散強制(DF)框架可用於訓練和取樣任意序列長度的有噪聲 token 圖片,其中的關鍵在於每個 token 的噪聲水平 k_t 會隨時間步驟而變化。

這篇論文關注的重點是時間序列資料,因此他們透過一種因果架構例項化了 DF,並由此得到了因果擴散強制(CDF)。簡單來說,這是使用基礎迴圈神經網路(RNN)獲得的一種最小實現。

權重為 θ 的 RNN 維護著獲悉過去 token 影響的隱藏狀態 z_t,其會透過一個迴圈層根據動態圖片而演化。當獲得輸入噪聲觀察圖片時,就以馬爾可夫方式更新該隱藏狀態。

當 k_t=0 時,這就是貝葉斯過濾中的後驗更新;而當 k_t= K(純噪聲、無資訊)時,這就等價於建模貝葉斯過濾中的「後驗分佈」p_θ(z_t | z_{t−1})。

給定隱藏狀態 z_t,觀察模型 p_θ(x_t^0 | z_t) 的目標是預測 x_t;這個單元的輸入 - 輸出行為與標準的條件擴散模型一樣:以條件變數 z_{t−1} 和有噪聲 token 為輸入,預測無噪聲的 x_t=x_t^0,並由此間接地透過仿射重新引數化預測噪聲 ε^{k_t}。因此,我們就可以直接使用經典的擴散目標來訓練(因果)擴散強制。根據噪聲預測結果 ε_θ,可以對上述單元進行引數化。然後,透過最小化以下損失來找到引數 θ:
圖片
演算法 1 給出了虛擬碼。重點在於,該損失捕獲了貝葉斯過濾和條件擴散的關鍵元素。該團隊也進一步重新推斷了用於擴散強制的擴散模型訓練中的常用技術,詳見原論文的附錄部分。他們也得出了一個非正式的定理。
圖片
定理 3.1(非正式)。擴散強制訓練流程(演算法 1)是在期望對數似然圖片上最佳化證據下限(ELBO)的重新加權,其中期望值會在噪聲水平上平均,而圖片是根據前向過程加噪。此外,在適當條件下,最佳化 (3.1) 式還可以同時最大化所有噪聲水平序列的似然下限。

擴散強制取樣和所得到的能力

演算法 2 描述了取樣過程,其定義是:在二維的 M × T 網格 K ∈ [K]^{M×T} 上指定噪聲排程;其中列對應於時間步驟 t,m 索引的行則決定了噪聲水平。
圖片
為了生成長度為 T 的整個序列,先將 token x_{1:T} 初始化為白噪聲,對應於噪聲水平 k = K。然後沿著網格逐行向下迭代,並從左到右逐列去噪,直到噪聲水平達到 K。到最後一行 m = 0 時,token 的噪聲已清理乾淨,即噪聲水平為 K_{0,t} ≡ 0。

這個取樣正規化會帶來如下新能力:

  • 讓自迴歸生成變得穩定
  • 保持未來的不確定
  • 長期引導能力

將擴散強制用於靈活的序列決策

擴散強制的新能力也帶來了新的可能性。該團隊基於此為序列決策(SDM)設計了一種全新框架,並且將其成功應用到了機器人和自主智慧體領域。

首先,定義一個馬爾可夫決策過程,該過程具有動態 p (s_{t+1}|s_t, a_t)、觀察 p (o_t|s_t) 和獎勵 p (r_t|s_t, a_t)。這裡的目標是訓練一個策略 π(a_t|o_{1:t}),使得軌跡 圖片的預期累積獎勵最大化。這裡分配 token x_t = [a_t, r_t, o_{t+1}]。一條軌跡就是一個序列 x_{1:T},其長度可能是可變的;訓練方式則如演算法 1 所示。

在執行過程的每一步 t,都有一個隱藏狀態 z_{t-1} 總結過去的無噪聲 token x_{1:t-1}。基於這個隱藏狀態,根據演算法 2 取樣一個規劃圖片,其中圖片包含預測的動作、獎勵和觀察。H 是一個前向觀察視窗,類似於模型預測控制中的未來預測。在採用了規劃的動作之後,環境會得到一個獎勵和下一個觀察,從而得到下一個 token。其中隱藏狀態可以根據後驗 p_θ(z_t|z_{t−1}, x_t, 0) 獲得更新。

該框架既可以作為策略,也可以作為規劃器,其優勢包括:

  • 具有靈活的規劃範圍
  • 可實現靈活的獎勵引導
  • 能實現蒙特卡洛樹引導(MCTG),從而實現未來不確定性

實驗

該團隊評估了擴散強制作為生成序列模型的優勢,其中涉及影片和時間序列預測規劃模仿學習等多種應用。

影片預測:一致且穩定的序列生成和無限展開

針對影片生成式建模任務,他們基於 Minecraft 遊戲影片和 DMLab 導航為因果擴散強制訓練了一個卷積 RNN 實現。

圖 3 展示了擴散強制與基準的定性結果。
圖片
可以看到,擴散強制能穩定地展開,甚至能超過其訓練範圍;而教師強制和全序列擴散基準會很快發散。

擴散規劃:MCTG、因果不確定性、靈活的範圍控制

擴散強制的能力能為決策帶來獨有的好處。該團隊使用一種標準的離線強化學習框架 D4RL 評估了新提出的決策框架。
圖片
表 1 給出了定性和定量的評估結果。可以看到,擴散強制在全部 6 個環境中都優於 Diffuser 和所有基準
圖片
可控的序列組合生成

該團隊發現,僅需修改取樣方案,就可以靈活地組合訓練時間觀察到的序列的子序列。

他們使用一個 2D 軌跡資料集進行了實驗:在一個方形平面上,所有軌跡都是始於一角並最終到達對角,形成一種十字形。

如上圖 1 所示,當不需要組合行為時,可讓 DF 保持完整記憶,複製十字形的分佈。當需要組合時,可讓模型使用 MPC 無記憶地生成更短的規劃,從而實現對這個十字形的子軌跡的縫合,得到 V 形軌跡。

機器人:長範圍模仿學習和穩健的視覺運動控制

擴散強制也為真實機器人的視覺運動控制帶來了新的機會。

模仿學習是一種常用的機器人操控技術,即學習專家演示的觀察到動作的對映。但是,缺乏記憶往往會讓模仿學習難以完成長範圍的任務。DF 不僅能緩解這個短板,還能讓模仿學習更穩健。

使用記憶進行模仿學習。透過遙控 Franka 機器人,該團隊收集了一個影片和動作資料集。如圖 4 所示,任務就是利用第三個位置交換蘋果和橘子的位置。水果的初始位置是隨機的,因此可能的目標狀態有兩個。
圖片
此外,當第三個位置有一個水果時,就無法透過當前觀察推斷出所需結果 —— 策略必須記住初始配置才能決定移動哪個水果。不同於常用的行為克隆方法,DF 可以自然地將記憶整合進自己的隱藏狀態中。結果發現,DF 能實現 80% 的成功率,而擴散策略(當前最佳的無記憶模仿學習演算法)卻失敗了。無限生成影片,還能規劃決策,擴散強制整合下一token預測與全序列擴散
此外,DF 還能更穩健地應對噪聲並助益機器人預訓練。

時間序列預測:擴散強制是一種優秀的通用序列模型

對於多變數時間序列預測任務,該團隊的研究表明 DF 足以與之前的擴散模型和基於 Transformer 的模型媲美。

更多技術細節和實驗結果請參閱原論文。

相關文章