參考資料:
- [2411.14499] Understanding World or Predicting Future? A Comprehensive Survey of World Models
- [1803.10122] World Models
- Learning Latent Dynamics for Planning from Pixels
- Kaixhin/PlaNet: Deep Planning Network: Control from pixels by latent planning with learned dynamics
世界模型
簡介
世界模型:一種理解世界當前狀態或預測其未來動態的工具。
世界模型的兩個主要功能:
- 構建內部表徵以理解世界運作機制。
- 預測未來狀態以模擬和指導決策。
分類
圖片來自[2411.14499] Understanding World or Predicting Future? A Comprehensive Survey of World Models.
- 作者按照模型的側重點不同,將世界模型分成兩個大類,即:
- Internal Representations.
- Future Predictions.
- 經常能在網上刷到的LeCun力推世界模型,說的是JEPA.
- 左邊分支的世界模型也可以做"future prediction",作為學習模型引數過程的一個副產物吧 (視覺模組的reconstruction)。
這裡討論的是兩篇world model for decision-making的文章。
World Models (2018)
AI社群中,首篇系統性介紹世界模型的文章。
人類的心理模型
簡單可以概括成以下幾點:
- 對於外部世界的大量資訊流,人腦能夠學習到外部世界時空資訊的抽象表示,作為我們對外部世界的"建模"。
- 我們所看到一切都基於腦中模型對未來的預測。
- 我們能夠基於這個預測模型本能地行動,在面對危險時做出快速的反射性行為。
打棒球的例子: 擊球手需要在毫秒級別的時間內決定如何揮棒 —— 這比視覺訊號到大腦的時間還要短。
在之後的世界模型結構和實驗中,都可以看到這個心理模型的影子。
模型結構
世界模型主要由兩個模組組成:視覺模組、記憶模組。
1. 視覺模組:將外部世界的高維觀測,壓縮成低維的特徵。
2. 記憶模組:整合歷史資訊,預測未來。
控制器會利用世界模型給出的資訊進行決策。
視覺模組
作者在文章中使用VAE的Encoder部分作為視覺模組。
記憶模組
作者在文章中使用MDN-RNN作為記憶模組。
- MDN指的是mixture density networks,就是一個建模混合模型的網路,文中使用的是高斯混合模型 (GMM),此時神經網路除了輸出每個高斯分佈的均值和標準差,還需要輸出用於選擇高斯分佈的類別分佈。
- MDN會接受一個temperature引數\(\tau\),用於調整不確定性。
- 在圖中,MDN-RNN建模的是\(P(z_{t+1}\mid a_t, z_t, h_t)\).
- 除了隱狀態之外,記憶模組可能還需要建模其他東西,比如獎勵\(P(r_{t+1} \mid a_t, z_t, h_t)\),遊戲結束的訊號\(P(\text{done}_{t+1} \mid a_t, z_t, h_t)\).
NOTE:為什麼要使用混合模型,即使VAE的隱變數空間只是一個對角高斯?作者的解釋是:混合模型中的離散部分 (選擇哪一個高斯組分),有利於建模環境中的離散隨機事件。比如說NPC在平靜狀態和警覺狀態下的表現不同。
控制器
作者將整個模型的複雜性都集中到了視覺和記憶模組,有意使得控制器的結構儘可能簡單:
就是單層的神經網路。
模型訓練和實驗
文章官網World Models,有gif演示,而且可以試玩模型"夢中"的遊戲。
訓練
兩個實驗都是先單獨訓練世界模型 (無監督):
- 使用隨機策略收集一系列的遊戲影像。
- 使用這些影像訓練好VAE。
- 在訓練好的VAE基礎上,訓練好MDN-RNN。
之後部署世界模型並訓練控制器。兩個實驗的主要區別在部署:
-
Car Racing實驗:直接在實際環境部署,訓練好了控制器之後,又給出了在模型"夢中"的模擬。
-
VizDoom實驗:先在"夢中"部署,訓練好了控制器之後,再將整個模型轉移到實際環境檢視效果。
NOTE:在兩個實驗中,世界模型都沒有建模環境的獎勵。第一個實驗中,獎勵只在訓練控制器的時候由實際環境給出;第二個實驗中,指標是存活時間,不需要獎勵。
REMARK:訓練成功之後,模型實際上成為了遊戲的"模擬器",學習到了遊戲邏輯 (角色中彈後會重新開始)、敵人行為 (按一定時間間隔發射子彈)、物理機制 (子彈飛行速度)等。
實驗
Car Racing:
VizDoom:
消融實驗1 -- 視覺模組+記憶模組的優越性
在Car Racing中,消融實驗顯示,單獨的視覺模組效果不如一整個的世界模型 (但是也已經超過了DQN和A3C)
消融實驗2 -- 用tau調整隨機性
在VizDoom實驗中,由於模型並非完全精確,控制器可能會利用模型的缺陷來在模擬器中達到高分,一旦部署到實際環境,控制器就不行了。
為了防止這一點,MDN-RNN預測的是具有隨機性的環境,並透過調整不確定性引數\(\tau\)來控制隨機性。在實驗中,\(\tau=1.15\)時效果最好。
當\(\tau=0.1\)時,模型幾乎是確定性的,這時候敵人甚至無法發射子彈,所以出現了在模擬器中非常高分,實際環境中卻非常低分的情況。
跑分對比實驗
- Car Racing實驗:取得的分數超過了先前的基於深度強化學習的方法,如DQN、A3C.
- VizDoom實驗:在夢中學會了如何躲避怪物的子彈,部署到實際環境後的存活時長也超過了先前。
迭代訓練過程
本文的實驗環境簡單,所以是使用隨機策略取樣,分別訓練三個模組。面對更復雜的任務,可能需要三個模組一起訓練,但是本文只是提了一下記憶模組和控制器一起訓練的流程:
三個模組一起訓練的好處是:
- 視覺模組會傾向於學習到有利於當前任務的特徵。
- 記憶模組可以對控制器進行學習,控制器又可以基於記憶模組繼續改進,如此往復。
- 可以使用訓練中的控制器進行軌跡取樣而不是隨機策略。
Learning Latent Dynamics for Planning from Pixels (2019)
相對於上一篇,這篇的改進:
- 假定了環境是部分可觀測馬爾可夫決策過程 (POMDP),世界模型就是在學習這個POMDP.
- 給出了一套結合模型預測控制 (MPC) 方法的訓練過程 —— Deep Planning Network (PlaNet).
- 提出基於確定性和隨機性結合的狀態空間模型 (RSSM),而不是僅有確定性狀態的RNN和僅有隨機性狀態的SSM.
- 給出了適用於多步預測的變分推斷方法 —— latent overshooting.
Problem setup
假定實際的環境是一個POMDP:
目標是學習到一個策略,能夠最大化期望累積回報\(\mathbb E[\sum r_t]\)。
Deep planning network
這裡先講世界模型+MPC的學習和規劃演算法。
while迴圈內部,總體上分成三個部分:模型學習,實時規劃+資料收集,更新資料庫。
模型學習
從資料庫中隨機抽取觀測序列的小批次,然後使用梯度方法學習。
實時規劃+資料收集
總體上就是一個有限時間域的MPC框架,在每個time step按三步走:
- Observe:獲得當前時刻的狀態。由於這裡在隱狀態空間進行規劃,所以需要從歷史的觀測資料中推斷當前狀態 (透過隱變數的後驗機率)。
- Predict and plan:利用當前學習到的模型,解一個有限時間域的最優控制問題,獲得一串動作序列。本文中的planner使用的是cross entropy method (CEM).
- Act:對環境使用這串動作序列的第一個動作\(a_t\),移動到下一個time step. 這裡用了一個trick,把取得的動作\(a_t\)重複了\(R\)次 (用相同的action,連續走了\(R\)步),取reward的總和作為當前時刻的reward,取最終的第\(R\)觀測\(o_{t+1}^R\)作為下一個時刻的觀測\(o_{t+1}\)。
更新資料庫
將上一個部分收集到的觀測序列加入到資料庫中,以供世界模型的進一步更新。
NOTE:相對於model-free RL演算法,model-based planning的一大優勢就是資料利用率提高了。體現在planning取得的觀測序列可以反覆用於世界模型的學習。
RSSM
這種模型也叫:Non-linear Kalman filter, sequential VAE, deep variational bayes filter,看了一眼相關的文章,好像要從頭到尾講明白 (像VAE那樣) 比較複雜。
這裡淺淺講一下世界模型的結構以及訓練的Loss。
Latent state-space model
使用下面的encoder來近似後驗機率:
都使用神經網路引數化的高斯分佈表示,其中observation model和encoder用的是卷積網路。
Training Objective
透過最大化log Evidence來訓練:
接下來推導ELBO.
先拆成邊際化的形式
把聯合機率拆開
寫成期望的形式
利用重要性取樣方法,轉變成從encoder取樣
鏈式分解,並利用模型的條件獨立性化簡 (機率圖參考下面的)
根據Jensen不等式,\(\ln \mathbb E[x] \ge \mathbb E[\ln(x)]\)
右邊可以寫成reconstruction + KL的形式,最後就是
確定性和隨機性結合 - RSSM
- 純確定性的世界模型:模型難以預測多種可能的未來情況;容易被planner利用模型缺陷 (在World Models中,透過MDN新增隨機性來緩解這一點,但本質還是確定性的)
- 純隨機性的世界模型:模型難以記住資訊,導致產生前後不一致的預測結果。
所以作者考慮將確定性和隨機性結合,稱這種結構為RSSM.
相對於上一篇,把記憶模組換成了RSSM。
Latent Overshooting
之前討論的都是\(s_t \to s_{t+1}\)的單步預測,如果每次單步預測都準確無誤,那多步預測肯定也沒問題。但是由於模型本身有侷限,所以不一定能很好的推廣到多部預測。
於是作者考慮了直接進行跨步的預測,先透過對中間幾步隱變數邊際化得到了跨步預測的轉移
並且推導了針對跨步預測的變分bound
把考慮不同的步幅\(d\),求和,就得到latent overshooting的目標函式
實驗結果
DeepMind control suite環境:影像作為觀測,連續動作空間。
消融實驗
-
驗證PlaNet的資料收集過程有優勢。Random Collection指的是用隨機策略收集資料而不是透過MPC;Random shooting指的是使用了MPC框架,但是不使用CEM,而是直接從1000條隨機採的動作序列裡選最好的那條。最後PlaNet在大部分情況都明顯好於另外兩種。
-
RSSM和SSM、GRU的對比。觀察到RSSM明顯好於後兩者,表明了確定性+隨機性結合的優勢。
-
是否加入latent overshooting作為變分目標。觀察到Latent overshooting使RSSM的表現輕微變差,但是在一些任務上讓DRNN的表現變好了。
跑分對比實驗
- PlaNet的分數能打敗A3C。
- PlaNet的分數總體不如D4PG,但是大部分任務相差不多。
- PlaNet在所有任務上,資料利用率都好於D4PG.
- PlaNet (CEM + 世界模型) 和 CEM + true simulator對比只差了一些,體現出世界模型較好地學習到了環境。
六個任務一起訓練
每次迴圈中,agent面對的可能是不同的環境,所以資料庫中抽取出來的軌跡也是打亂的。
最後跑分不如單獨訓練,但是體現出了agent能夠自己判斷出面對的是哪個任務了。
程式碼選講
程式碼來自:Kaixhin/PlaNet: Deep Planning Network: Control from pixels by latent planning with learned dynamics
主要是看看transition model和模型訓練過程。解釋都在註釋裡,有部分註釋是程式碼庫原有的。
Transition model
class TransitionModel(jit.ScriptModule):
__constants__ = ['min_std_dev']
def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1):
super().__init__()
self.act_fn = getattr(F, activation_function)
self.min_std_dev = min_std_dev
self.fc_embed_state_action = nn.Linear(state_size + action_size, belief_size) # combine s_t and a_t to comb(s_t, a_t)
self.rnn = nn.GRUCell(belief_size, belief_size) # from comb(s_t, a_t), h_t to h_t+1
self.fc_embed_belief_prior = nn.Linear(belief_size, hidden_size) # from h_t to z_t
self.fc_state_prior = nn.Linear(hidden_size, 2 * state_size) # parameterized prior of s_t, from z_t to mean and std
self.fc_embed_belief_posterior = nn.Linear(belief_size + embedding_size, hidden_size) # from h_t and e_t to z_t
self.fc_state_posterior = nn.Linear(hidden_size, 2 * state_size) # parameterized posterior of s_t, from z_t to mean and std
# Operates over (previous) state, (previous) actions, (previous) belief, (previous) nonterminals (mask), and (current) observations
# Diagram of expected inputs and outputs for T = 5 (-x- signifying beginning of output belief/state that gets sliced off):
# t : 0 1 2 3 4 5
# o : -X--X--X--X--X- 設定了初始的隱狀態是None,所以不考慮0時刻的obs
# a : -X--X--X--X--X- 不考慮最後一個action,因為最後一個action沒有後續的obs
# n : -X--X--X--X--X-
# pb: -X-
# ps: -X-
# b : -x--X--X--X--X--X-
# s : -x--X--X--X--X--X-
# 輸入的shape都是(time_step, batch_size, *)
@jit.script_method
def forward(self, prev_state:torch.Tensor, actions:torch.Tensor, prev_belief:torch.Tensor, observations:Optional[torch.Tensor]=None, nonterminals:Optional[torch.Tensor]=None) -> List[torch.Tensor]:
# 後面都是動態更新,為了保留grad,不能使用單個tensor作為buffer,所以建立了幾個list
T = actions.size(0) + 1 # 實際需要的list長度,參考上面的圖
beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = \
[torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T, [torch.empty(0)] * T
beliefs[0], prior_states[0], posterior_states[0] = prev_belief, prev_state, prev_state # 0時刻賦初值
# 每次迴圈開始,是已知t時刻的資訊,進一步計算t+1時刻的資訊
for t in range(T - 1):
# 根據情況合適的s,因為模型可以在脫離observations的情況下自己預測
# 如果observations為None,則使用先驗狀態 (模型一步步生成出來的),否則使用後驗狀態 (根據歷史的obs和action推斷出來的)
_state = prior_states[t] if observations is None else posterior_states[t]
# terminal則說明這段序列已經結束了,所以把狀態mask掉 (就是0)
_state = _state if nonterminals is None else _state * nonterminals[t]
# 注意下面每一塊的hidden是臨時變數,表示的是不同的意思
# 計算確定性隱狀態h = f(s_t, a_t, h_t)
hidden = self.act_fn(self.fc_embed_state_action(torch.cat([_state, actions[t]], dim=1))) # s和a先拼在一起
beliefs[t + 1] = self.rnn(hidden, beliefs[t]) # 對應機率圖中從s,a,h到h的實線
# 計算隱狀態s的先驗 p(s_t|s_t-1,a_t-1)
hidden = self.act_fn(self.fc_embed_belief_prior(beliefs[t + 1])) # 對應機率圖中從h到s的實線
prior_means[t + 1], _prior_std_dev = torch.chunk(self.fc_state_prior(hidden), 2, dim=1)
prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + self.min_std_dev # Trick: 使用softplus來保證std_devs為正,並且使用min_std_dev來保證std_devs不會太小
prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])
# 計算隱狀態s的後驗 q(s_t|o≤t,a<t)
if observations is not None: # 只有observations不為None時,才計算後驗
t_ = t - 1 # 這是實現的問題,因為傳進來的是obs[1:],所以應該用t_+1才能索引到對應的obs
hidden = self.act_fn(self.fc_embed_belief_posterior(torch.cat([beliefs[t + 1], observations[t_ + 1]], dim=1))) # 對應機率圖中的兩條虛線
posterior_means[t + 1], _posterior_std_dev = torch.chunk(self.fc_state_posterior(hidden), 2, dim=1)
posterior_std_devs[t + 1] = F.softplus(_posterior_std_dev) + self.min_std_dev
posterior_states[t + 1] = posterior_means[t + 1] + posterior_std_devs[t + 1] * torch.randn_like(posterior_means[t + 1])
# 返回h,s,以及先驗和後驗的均值和方差
hidden = [torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0), torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0)]
if observations is not None:
hidden += [torch.stack(posterior_states[1:], dim=0), torch.stack(posterior_means[1:], dim=0), torch.stack(posterior_std_devs[1:], dim=0)]
return hidden
世界模型訓練
只擷取了一小部分,重點看loss func是如何計算的。
# Model fitting
losses = []
for s in tqdm(range(args.collect_interval)):
# Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
observations, actions, rewards, nonterminals = D.sample(args.batch_size, args.chunk_size) # Transitions start at time t = 0
# Create initial belief and state for time t = 0
init_belief, init_state = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device)
# Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
# 一次把整個隱狀態序列全部計算出來
beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs =\
transition_model(init_state, actions[:-1], init_belief, bottle(encoder, (observations[1:], )), nonterminals[:-1])
# Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
# Reconstruction loss都使用MSE
# mean(dim=(0, 1))對batch和time進行平均
observation_loss =\
F.mse_loss(bottle(observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum(dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
reward_loss =\
F.mse_loss(bottle(reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))
# KL loss, 計算了後驗q(s_t|o≤t,a<t)和先驗p(s_t|s_t-1,a_t-1)的KL散度
kl_loss =\
torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), free_nats).mean(dim=(0, 1)) # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
"""
後面的部分略
"""