RL 基礎 | 如何復現 PPO,以及一些踩坑經歷

MoonOut發表於2024-11-21

最近在復現 PPO 跑 MiniGrid,記錄一下…

這裡跑的環境是 Empty-5x5 和 8x8,都是簡單環境,主要驗證 PPO 實現是否正確。

01 Proximal policy Optimization(PPO)

(參考:知乎 | Proximal Policy Optimization (PPO) 演算法理解:從策略梯度開始

首先,策略梯度方法 的梯度形式是

\[\nabla_\theta J(\theta)\approx \frac1n \sum_{i=0}^{n-1} R(\tau_i) \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t|s_t) \tag1 \]

然而,傳統策略梯度方法容易一步走的太多,以至於越過了中間比較好的點(在參考知乎部落格裡稱為 overshooting)。一個直觀的想法是限制策略每次不要更新太多,比如去約束 新策略 舊策略之間的 KL 散度(公式是 plog(p/q)):

\[D_{KL}(\pi_\theta | \pi_{\theta+\Delta \theta}) = \mathbb E_{s,a} \pi_\theta(a|s)\log\frac{\pi_\theta(a|s)}{\pi_{\theta+\Delta \theta}(a|s)} \le \epsilon \tag2 \]

我們把這個約束進行拉格朗日鬆弛,將它變成一個懲罰項:

\[\Delta\theta^* = \arg\max_{\Delta\theta} J(\theta+\Delta\theta) - \lambda [D_{KL}(\pi_\theta | \pi_{\theta+\Delta \theta})-\epsilon] \tag3 \]

然後再使用一些數學近似技巧,可以得到自然策略梯度(NPG)演算法。

NPG 演算法貌似還有種種問題,比如 KL 散度的約束太緊,導致每次更新後的策略效能沒有提升。我們希望每次策略更新後都帶來效能提升,因此計算 新策略 舊策略之間 預期回報的差異。這裡採用計算 advantage 的方式:

\[J(\pi_{\theta+\Delta\theta})=J(\pi_{\theta})+\mathbb E_{\tau\sim\pi_{\theta+\Delta\theta}}\sum_{t=0}^\infty \gamma^tA^{\pi_{\theta}}(s_t,a_t) \tag{4} \]

其中優勢函式(advantage)的定義是:

\[A^{\pi_{\theta}}(s_t,a_t)=\mathbb E(Q^{\pi_{\theta}}(s_t,a_t)-V^{\pi_{\theta}}(s_t)) \tag{5} \]

在公式 (4) 中,我們計算的 advantage 是在 新策略 的期望下的。但是,在新策略下蒙特卡洛取樣(rollout)來算 advantage 期望太麻煩了,因此我們在原策略下 rollout,並進行 importance sampling,假裝計算的是新策略下的 advantage。這個 advantage 被稱為替代優勢(surrogate advantage):

\[\mathcal{L}_{\pi_{\theta}}\left(\pi_{\theta+\Delta\theta}\right) = J\left(\pi_{\theta+\Delta\theta}\right)-J\left(\pi_{\theta}\right)\approx E_{s\sim\rho_{\pi\theta}}\frac{\pi_{\theta+\Delta\theta}(a\mid s)}{\pi_{\theta}(a\mid s)} A^{\pi_{\theta}}(s, a) \tag6 \]

所產生的近似誤差,貌似可以用兩種策略之間最壞情況的 KL 散度表示:

\[J(\pi_{\theta+\Delta\theta})-J(\pi_{\theta})\geq\mathcal{L}_{\pi\theta}(\pi_{\theta+\Delta\theta})-CD_{KL}^{\max}(\pi_{\theta}||\pi_{\theta+\Delta\theta}) \tag7 \]

其中 C 是一個常數。這貌似就是 TRPO 的單調改進定理,即,如果我們改進下限 RHS,我們也會將目標 LHS 改進至少相同的量。

基於 TRPO 演算法,我們可以得到 PPO 演算法。PPO Penalty 跟 TRPO 比較相近:

\[\Delta\theta^{*}=\underset{\Delta\theta}{\text{argmax}} \Big[\mathcal{L}_{\theta+\Delta\theta}(\theta+\Delta\theta)-\beta\cdot \mathcal{D}_{KL}(\pi_{\theta}\parallel\pi_{\theta+\Delta\theta})\Big] \tag 8 \]

其中,KL 散度懲罰的 β 是啟發式確定的:PPO 會設定一個目標散度 \(\delta\),如果最終更新的散度超過目標散度的 1.5 倍,則下一次迭代我們將加倍 β 來加重懲罰。相反,如果更新太小,我們將 β 減半,從而擴大信任域。

接下來是 PPO Clip,這貌似是目前最常用的 PPO。PPO Penalty 用 β 來懲罰策略變化,而 PPO Clip 與此不同,直接限制策略可以改變的範圍。我們重新定義 surrogate advantage:

\[\begin{aligned} \mathcal{L}_{\pi_{\theta}}^{CLIP}(\pi_{\theta_{k}}) = \mathbb E_{\tau\sim\pi_{\theta}}\bigg[\sum_{t=0}^{T} \min\Big( & \rho_{t}(\pi_{\theta}, \pi_{\theta_{k}})A_{t}^{\pi_{\theta_{k}}}, \\ & \text{clip} (\rho_{t}(\pi_{\theta},\pi_{\theta_{k}}), 1-\epsilon, 1+\epsilon) A_{t}^{\pi_{\theta_{k}}} \Big)\bigg] \end{aligned} \tag 9 \]

其中, \(\rho_{t}\) 為重要性取樣的 ratio:

\[\rho_{t}(\theta)=\frac{\pi_{\theta}(a_{t}\mid s_{t})}{\pi_{\theta_{k}}(a_{t}\mid s_{t})} \tag{10} \]

公式 (9) 中,min 括號裡的第一項是 ratio 和 advantage 相乘,代表新策略下的 advantage;min 括號裡的第二項是對 ration 進行的 clip 與 advantage 的相乘。這個 min 貌似可以限制策略變化不要太大。

02 如何復現 PPO(參考 stable baselines3 和 clean RL)

  • stable baselines3 的 PPO:https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/ppo/ppo.py
  • clean RL 的 PPO:https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py

程式碼主要結構如下,以 stable baselines3 為例:(僅保留主要結構,相當於虛擬碼,不保證正確性)

import torch
import torch.nn.functional as F
import numpy as np

# 1. collect rollout
self.policy.eval()
rollout_buffer.reset()
while not done:
    actions, values, log_probs = self.policy(self._last_obs)
    new_obs, rewards, dones, infos = env.step(clipped_actions)
    rollout_buffer.add(
        self._last_obs, actions, rewards,
        self._last_episode_starts, values, log_probs,
    )
    self._last_obs = new_obs
    self._last_episode_starts = dones

with torch.no_grad():
    # Compute value for the last timestep
    values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) 

rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)


# 2. policy optimization
for rollout_data in self.rollout_buffer.get(self.batch_size):
    actions = rollout_data.actions
    values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
    advantages = rollout_data.advantages
    # Normalize advantage
    if self.normalize_advantage and len(advantages) > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # ratio between old and new policy, should be one at the first iteration
    ratio = torch.exp(log_prob - rollout_data.old_log_prob)

    # clipped surrogate loss
    policy_loss_1 = advantages * ratio
    policy_loss_2 = advantages * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
    policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

    # Value loss using the TD(gae_lambda) target
    value_loss = F.mse_loss(rollout_data.returns, values_pred)

    # Entropy loss favor exploration
    entropy_loss = -torch.mean(entropy)

    loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

    # Optimization step
    self.policy.optimizer.zero_grad()
    loss.backward()
    # Clip grad norm
    torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
    self.policy.optimizer.step()

大致流程:收集當前策略的 rollout → 計算 advantage → 策略最佳化。

計算 advantage 是由 rollout_buffer.compute_returns_and_advantage 函式實現的:

rb = rollout_buffer
last_gae_lam = 0
for step in reversed(range(buffer_size)):
    if step == buffer_size - 1:
        next_non_terminal = 1.0 - dones.astype(np.float32)
        next_values = last_values
    else:
        next_non_terminal = 1.0 - rb.episode_starts[step + 1]
        next_values = rb.values[step + 1]
    delta = rb.rewards[step] + gamma * next_values * next_non_terminal - rb.values[step]  # (1)
    last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam  # (2)
    rb.advantages[step] = last_gae_lam
rb.returns = rb.advantages + rb.values

其中,

  • (1) 行透過類似於 TD error 的形式(A = r + γV(s') - V(s)),計算當前 t 時刻的 advantage;
  • (2) 行則是把 t+1 時刻的 advantage 乘 gamma 和 gae_lambda 傳遞過來。

03 記錄一些踩坑經歷

  1. PPO 在收集 rollout 的時候,要在分佈裡取樣,而非採用 argmax 動作,否則沒有 exploration。(PPO 在分佈裡取樣 action,這樣來保證探索,而非使用 epsilon greedy 等機制;聽說 epsilon greedy 機制是 value-based 方法用的)
  2. 如果 policy 網路裡有(比如說)batch norm,rollout 時應該把 policy 開 eval 模式,這樣就不會出錯。
  3. (但是,不要加 batch norm,加 batch norm 效能就不好了。聽說 RL 不能加 batch norm)
  4. minigrid 簡單環境,RNN 加不加貌似都可以(?)
  5. 在算 entropy loss 的時候,要用真 entropy,從 Categorical 分佈裡得到的 entropy;不要用 -logprob 近似的,不然會導致策略分佈 熵變得很小 炸掉。


相關文章