用於決策的世界模型 -- 論文 World Models (2018) & PlaNet (2019) 講解

伊犁纯流莱發表於2025-01-14

參考資料:

  • [2411.14499] Understanding World or Predicting Future? A Comprehensive Survey of World Models
  • [1803.10122] World Models
  • Learning Latent Dynamics for Planning from Pixels
  • Kaixhin/PlaNet: Deep Planning Network: Control from pixels by latent planning with learned dynamics

世界模型

簡介

世界模型:一種理解世界當前狀態預測其未來動態的工具。

世界模型的兩個主要功能

  1. 構建內部表徵以理解世界運作機制。
  2. 預測未來狀態以模擬和指導決策。

分類

image

圖片來自[2411.14499] Understanding World or Predicting Future? A Comprehensive Survey of World Models.

  • 作者按照模型的側重點不同,將世界模型分成兩個大類,即:
    • Internal Representations.
    • Future Predictions.
  • 經常能在網上刷到的LeCun力推世界模型,說的是JEPA.
  • 左邊分支的世界模型也可以做"future prediction",作為學習模型引數過程的一個副產物吧 (視覺模組的reconstruction)。

這裡討論的是兩篇world model for decision-making的文章。

World Models (2018)

image

AI社群中,首篇系統性介紹世界模型的文章。

人類的心理模型

簡單可以概括成以下幾點:

  • 對於外部世界的大量資訊流,人腦能夠學習到外部世界時空資訊的抽象表示,作為我們對外部世界的"建模"。
  • 我們所看到一切都基於腦中模型對未來的預測。

image

  • 我們能夠基於這個預測模型本能地行動,在面對危險時做出快速的反射性行為。

打棒球的例子: 擊球手需要在毫秒級別的時間內決定如何揮棒 —— 這比視覺訊號到大腦的時間還要短。

在之後的世界模型結構和實驗中,都可以看到這個心理模型的影子。

模型結構

image

世界模型主要由兩個模組組成:視覺模組、記憶模組。
1. 視覺模組:將外部世界的高維觀測,壓縮成低維的特徵。
2. 記憶模組:整合歷史資訊,預測未來。

控制器會利用世界模型給出的資訊進行決策。

視覺模組

作者在文章中使用VAE的Encoder部分作為視覺模組。
image

記憶模組

作者在文章中使用MDN-RNN作為記憶模組。
image

  • MDN指的是mixture density networks,就是一個建模混合模型的網路,文中使用的是高斯混合模型 (GMM),此時神經網路除了輸出每個高斯分佈的均值和標準差,還需要輸出用於選擇高斯分佈的類別分佈。
  • MDN會接受一個temperature引數\(\tau\),用於調整不確定性。
  • 在圖中,MDN-RNN建模的是\(P(z_{t+1}\mid a_t, z_t, h_t)\).
  • 除了隱狀態之外,記憶模組可能還需要建模其他東西,比如獎勵\(P(r_{t+1} \mid a_t, z_t, h_t)\),遊戲結束的訊號\(P(\text{done}_{t+1} \mid a_t, z_t, h_t)\).

NOTE:為什麼要使用混合模型,即使VAE的隱變數空間只是一個對角高斯?作者的解釋是:混合模型中的離散部分 (選擇哪一個高斯組分),有利於建模環境中的離散隨機事件。比如說NPC在平靜狀態和警覺狀態下的表現不同。

控制器

作者將整個模型的複雜性都集中到了視覺和記憶模組,有意使得控制器的結構儘可能簡單:

\[a_t = W_c[z_t~~h_t] + b_c \]

就是單層的神經網路。

模型訓練和實驗

文章官網World Models,有gif演示,而且可以試玩模型"夢中"的遊戲。

訓練

兩個實驗都是先單獨訓練世界模型 (無監督):

  1. 使用隨機策略收集一系列的遊戲影像。
  2. 使用這些影像訓練好VAE。
  3. 在訓練好的VAE基礎上,訓練好MDN-RNN。

之後部署世界模型並訓練控制器。兩個實驗的主要區別在部署:

  • Car Racing實驗:直接在實際環境部署,訓練好了控制器之後,又給出了在模型"夢中"的模擬。image

  • VizDoom實驗:先在"夢中"部署,訓練好了控制器之後,再將整個模型轉移到實際環境檢視效果。image

NOTE:在兩個實驗中,世界模型都沒有建模環境的獎勵。第一個實驗中,獎勵只在訓練控制器的時候由實際環境給出;第二個實驗中,指標是存活時間,不需要獎勵。

REMARK:訓練成功之後,模型實際上成為了遊戲的"模擬器",學習到了遊戲邏輯 (角色中彈後會重新開始)、敵人行為 (按一定時間間隔發射子彈)、物理機制 (子彈飛行速度)等。

實驗

Car Racing:
image

VizDoom:
image

消融實驗1 -- 視覺模組+記憶模組的優越性
在Car Racing中,消融實驗顯示,單獨的視覺模組效果不如一整個的世界模型 (但是也已經超過了DQN和A3C)

消融實驗2 -- 用tau調整隨機性
在VizDoom實驗中,由於模型並非完全精確,控制器可能會利用模型的缺陷來在模擬器中達到高分,一旦部署到實際環境,控制器就不行了。

為了防止這一點,MDN-RNN預測的是具有隨機性的環境,並透過調整不確定性引數\(\tau\)來控制隨機性。在實驗中,\(\tau=1.15\)時效果最好。

\(\tau=0.1\)時,模型幾乎是確定性的,這時候敵人甚至無法發射子彈,所以出現了在模擬器中非常高分,實際環境中卻非常低分的情況。

跑分對比實驗

  • Car Racing實驗:取得的分數超過了先前的基於深度強化學習的方法,如DQN、A3C.
  • VizDoom實驗:在夢中學會了如何躲避怪物的子彈,部署到實際環境後的存活時長也超過了先前。

迭代訓練過程

本文的實驗環境簡單,所以是使用隨機策略取樣,分別訓練三個模組。面對更復雜的任務,可能需要三個模組一起訓練,但是本文只是提了一下記憶模組和控制器一起訓練的流程:
image

三個模組一起訓練的好處是:

  1. 視覺模組會傾向於學習到有利於當前任務的特徵。
  2. 記憶模組可以對控制器進行學習,控制器又可以基於記憶模組繼續改進,如此往復。
  3. 可以使用訓練中的控制器進行軌跡取樣而不是隨機策略。

Learning Latent Dynamics for Planning from Pixels (2019)

image

相對於上一篇,這篇的改進:

  1. 假定了環境是部分可觀測馬爾可夫決策過程 (POMDP),世界模型就是在學習這個POMDP.
  2. 給出了一套結合模型預測控制 (MPC) 方法的訓練過程 —— Deep Planning Network (PlaNet).
  3. 提出基於確定性和隨機性結合的狀態空間模型 (RSSM),而不是僅有確定性狀態的RNN和僅有隨機性狀態的SSM.
  4. 給出了適用於多步預測的變分推斷方法 —— latent overshooting.

Problem setup

假定實際的環境是一個POMDP:
image

目標是學習到一個策略,能夠最大化期望累積回報\(\mathbb E[\sum r_t]\)

Deep planning network

這裡先講世界模型+MPC的學習和規劃演算法。

image

while迴圈內部,總體上分成三個部分:模型學習,實時規劃+資料收集,更新資料庫。

模型學習

從資料庫中隨機抽取觀測序列的小批次,然後使用梯度方法學習。

實時規劃+資料收集

總體上就是一個有限時間域的MPC框架,在每個time step按三步走:

  1. Observe:獲得當前時刻的狀態。由於這裡在隱狀態空間進行規劃,所以需要從歷史的觀測資料中推斷當前狀態 (透過隱變數的後驗機率)。
  2. Predict and plan:利用當前學習到的模型,解一個有限時間域的最優控制問題,獲得一串動作序列。本文中的planner使用的是cross entropy method (CEM).
  3. Act:對環境使用這串動作序列的第一個動作\(a_t\),移動到下一個time step. 這裡用了一個trick,把取得的動作\(a_t\)重複了\(R\)次 (用相同的action,連續走了\(R\)步),取reward的總和作為當前時刻的reward,取最終的第\(R\)觀測\(o_{t+1}^R\)作為下一個時刻的觀測\(o_{t+1}\)

更新資料庫

將上一個部分收集到的觀測序列加入到資料庫中,以供世界模型的進一步更新。

NOTE:相對於model-free RL演算法,model-based planning的一大優勢就是資料利用率提高了。體現在planning取得的觀測序列可以反覆用於世界模型的學習。

RSSM

這種模型也叫:Non-linear Kalman filter, sequential VAE, deep variational bayes filter,看了一眼相關的文章,好像要從頭到尾講明白 (像VAE那樣) 比較複雜。

這裡淺淺講一下世界模型的結構以及訓練的Loss。

Latent state-space model

image

使用下面的encoder來近似後驗機率:

\[q(s_{\le t} \mid o_{\le t},a_{<t}) = \prod_{t=1}^T q(s_t\mid s_{t-1},a_{t-1},o_t) \]

都使用神經網路引數化的高斯分佈表示,其中observation model和encoder用的是卷積網路。

Training Objective

透過最大化log Evidence來訓練:

\[\arg\max \ln p(o_{\le t} \mid a_{<t}) \]

接下來推導ELBO.

先拆成邊際化的形式

\[\ln p(o_{\le T} \mid a_{<T}) = \ln \int p(o_{\le T}, s_{\le T} \mid a_{<T}) \text{d}s \]

把聯合機率拆開

\[\ln p(o_{\le T} \mid a_{<T}) = \ln \int p(o_{\le T} \mid s_{\le T}, a_{<T}) p(s_{\le T} \mid a_{<T}) \text{d}s \]

寫成期望的形式

\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{p(s_{\le t}\mid a_{<t})}[ p(o_{\le t} \mid s_{\le t}, a_{<t}) ] \]

利用重要性取樣方法,轉變成從encoder取樣

\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[ p(o_{\le t} \mid s_{\le t}, a_{<t}) p(s_{\le t}\mid a_{<t}) / q(s_{\le t}\mid o_{\le t},a_{<t})] \]

鏈式分解,並利用模型的條件獨立性化簡 (機率圖參考下面的)

\[\ln p(o_{\le t} \mid a_{<t}) = \ln \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[\prod p(o_t \mid s_t) p(s_t \mid s_{t-1},a_{t-1}) / q(s_{\le t}\mid o_{\le t},a_{<t})] \]

根據Jensen不等式,\(\ln \mathbb E[x] \ge \mathbb E[\ln(x)]\)

\[\ln p(o_{\le t} \mid a_{<t}) \ge \mathbb E_{q(s_{\le t}\mid o_{\le t},a_{<t})}[\sum_t \ln p(o_t \mid s_t) + \ln p(s_t \mid s_{t-1},a_{t-1}) - \ln q(s_{\le t}\mid o_{\le t},a_{<t})] \]

右邊可以寫成reconstruction + KL的形式,最後就是
image

確定性和隨機性結合 - RSSM

image

  • 純確定性的世界模型:模型難以預測多種可能的未來情況;容易被planner利用模型缺陷 (在World Models中,透過MDN新增隨機性來緩解這一點,但本質還是確定性的)
  • 純隨機性的世界模型:模型難以記住資訊,導致產生前後不一致的預測結果。

所以作者考慮將確定性和隨機性結合,稱這種結構為RSSM.

相對於上一篇,把記憶模組換成了RSSM。

Latent Overshooting

之前討論的都是\(s_t \to s_{t+1}\)的單步預測,如果每次單步預測都準確無誤,那多步預測肯定也沒問題。但是由於模型本身有侷限,所以不一定能很好的推廣到多部預測。

於是作者考慮了直接進行跨步的預測,先透過對中間幾步隱變數邊際化得到了跨步預測的轉移
image

並且推導了針對跨步預測的變分bound
image

把考慮不同的步幅\(d\),求和,就得到latent overshooting的目標函式
image

實驗結果

image

DeepMind control suite環境:影像作為觀測,連續動作空間。

消融實驗

  • 驗證PlaNet的資料收集過程有優勢。Random Collection指的是用隨機策略收集資料而不是透過MPC;Random shooting指的是使用了MPC框架,但是不使用CEM,而是直接從1000條隨機採的動作序列裡選最好的那條。最後PlaNet在大部分情況都明顯好於另外兩種。
    image

  • RSSM和SSM、GRU的對比。觀察到RSSM明顯好於後兩者,表明了確定性+隨機性結合的優勢。
    image

  • 是否加入latent overshooting作為變分目標。觀察到Latent overshooting使RSSM的表現輕微變差,但是在一些任務上讓DRNN的表現變好了。
    image

跑分對比實驗
image

  • PlaNet的分數能打敗A3C。
  • PlaNet的分數總體不如D4PG,但是大部分任務相差不多。
  • PlaNet在所有任務上,資料利用率都好於D4PG.
  • PlaNet (CEM + 世界模型) 和 CEM + true simulator對比只差了一些,體現出世界模型較好地學習到了環境。

六個任務一起訓練
每次迴圈中,agent面對的可能是不同的環境,所以資料庫中抽取出來的軌跡也是打亂的。
image

最後跑分不如單獨訓練,但是體現出了agent能夠自己判斷出面對的是哪個任務了。

程式碼選講

程式碼來自:Kaixhin/PlaNet: Deep Planning Network: Control from pixels by latent planning with learned dynamics

主要是看看transition model和模型訓練過程。解釋都在註釋裡,有部分註釋是程式碼庫原有的。

Transition model

class TransitionModel(jit.ScriptModule):
  __constants__ = ['min_std_dev']

  def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1):
    super().__init__()
    self.act_fn = getattr(F, activation_function)
    self.min_std_dev = min_std_dev
    self.fc_embed_state_action = nn.Linear(state_size + action_size, belief_size) # combine s_t and a_t to comb(s_t, a_t)
    self.rnn = nn.GRUCell(belief_size, belief_size) # from comb(s_t, a_t), h_t to h_t+1
    self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size) # from h_t to z_t
    self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size) # parameterized prior of s_t, from z_t to mean and std
    self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size) # from h_t and e_t to z_t
    self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size) # parameterized posterior of s_t, from z_t to mean and std

  # Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations
  # Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off):
  # t :  0  1  2  3  4  5
  # o :    -X--X--X--X--X-  設定了初始的隱狀態是None,所以不考慮0時刻的obs
  # a : -X--X--X--X--X-     不考慮最後一個action,因為最後一個action沒有後續的obs
  # n : -X--X--X--X--X-
  # pb: -X-
  # ps: -X-
  # b : -x--X--X--X--X--X-
  # s : -x--X--X--X--X--X-

  # 輸入的shape都是(time_step, batch_size, *)
  @jit.script_method
  def forward(self, prev_state:torch.Tensor, actions:torch.Tensor, prev_belief:torch.Tensor, observations:Optional[torch.Tensor]=None, nonterminals:Optional[torch.Tensor]=None) -> List[torch.Tensor]:
    # 後面都是動態更新,為了保留grad,不能使用單個tensor作為buffer,所以建立了幾個list
    T = actions.size(0) + 1 # 實際需要的list長度,參考上面的圖
    beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = \
      [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T
    beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state # 0時刻賦初值

    # 每次迴圈開始,是已知t時刻的資訊,進一步計算t+1時刻的資訊
    for t in range(T - 1):
      # 根據情況合適的s,因為模型可以在脫離observations的情況下自己預測
      # 如果observations為None,則使用先驗狀態 (模型一步步生成出來的),否則使用後驗狀態 (根據歷史的obs和action推斷出來的)
      _state = prior_states[t] if observations is None else posterior_states[t] 
      # terminal則說明這段序列已經結束了,所以把狀態mask掉 (就是0)
      _state = _state if nonterminals is None else _state * nonterminals[t]  

      # 注意下面每一塊的hidden是臨時變數,表示的是不同的意思

      # 計算確定性隱狀態h = f(s_t, a_t, h_t)
      hidden = self.act_fn(self.fc_embed_state_action(torch.cat([_state, actions[t]], dim=1))) # s和a先拼在一起
      beliefs[t + 1] = self.rnn(hidden, beliefs[t]) # 對應機率圖中從s,a,h到h的實線

      # 計算隱狀態s的先驗 p(s_t|s_t-1,a_t-1)
      hidden = self.act_fn(self.fc_embed_belief_prior(beliefs[t + 1])) # 對應機率圖中從h到s的實線
      prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1)
      prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev # Trick: 使用softplus來保證std_devs為正,並且使用min_std_dev來保證std_devs不會太小
      prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])     

      # 計算隱狀態s的後驗 q(s_t|o≤t,a<t)
      if observations is not None: # 只有observations不為None時,才計算後驗
        t_ = t - 1  # 這是實現的問題,因為傳進來的是obs[1:],所以應該用t_+1才能索引到對應的obs
        hidden = self.act_fn(self.fc_embed_belief_posterior(torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1))) # 對應機率圖中的兩條虛線
        posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1)
        posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev
        posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1])

    # 返回h,s,以及先驗和後驗的均值和方差
    hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)]
    if observations is not None:
      hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)]
    return hidden

世界模型訓練
只擷取了一小部分,重點看loss func是如何計算的。

  # Model fitting
  losses = []
  for s in tqdm(range(args.collect_interval)):
    # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
    observations, actions, rewards, nonterminals = D.sample(args.batch_size, args.chunk_size)  # Transitions start at time t = 0

    # Create initial belief and state for time t = 0
    init_belief, init_state = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device)

    # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
    # 一次把整個隱狀態序列全部計算出來
    beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs =\
      transition_model(init_state, actions[:-1], init_belief, bottle(encoder, (observations[1:], )), nonterminals[:-1])

    # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
    # Reconstruction loss都使用MSE
    # mean(dim=(0, 1))對batch和time進行平均
    observation_loss =\
      F.mse_loss(bottle(observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum(dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
    reward_loss =\
      F.mse_loss(bottle(reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))
    # KL loss, 計算了後驗q(s_t|o≤t,a<t)和先驗p(s_t|s_t-1,a_t-1)的KL散度
    kl_loss =\
      torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), free_nats).mean(dim=(0, 1))  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out

"""
後面的部分略
"""

相關文章