BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍

機器之心發表於2020-05-28
pixel-based RL 演算法逆襲,BAIR 提出將對比學習與 RL 相結合的演算法,其 sample-efficiency 匹敵 state-based RL。

此次研究的本質在於回答一個問題—使用影像作為觀測值(pixel-based)的 RL 是否能夠和以座標狀態作為觀測值的 RL 一樣有效?傳統意義上,大家普遍認為以影像為觀測值的 RL 資料效率較低,通常需要一億個互動的 step 來解決 Atari 遊戲那樣的基準測試任務。

研究人員介紹了 CURL:一種用於強化學習的無監督對比表徵。CURL 使用對比學習的方式從原始畫素中提取高階特徵,並在提取的特徵之上執行異策略控制。在 DeepMind Control Suite 和 Atari Games 中的複雜任務上,CURL 優於以前的 pixel-based 的方法(包括 model-based 和 model-free),在 100K 互動步驟基準測試中,其效能分別提高了 2.8 倍以及 1.6 倍。在 DeepMind Control Suite 上,CURL 是第一個幾乎與基於狀態特徵方法的 sample-efficiency 和效能所匹配的基於影像的演算法。

BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍


  • 論文連結:https://arxiv.org/abs/2004.04136

  • 網站:https://mishalaskin.github.io/curl/

  • GitHub 連結:https://github.com/MishaLaskin/curl


背景介紹

CURL 是將對比學習與 RL 相結合的通用框架。理論上,可以在 CURL pipeline 中使用任一 RL 演算法,無論是同策略還是異策略。對於連續控制基準而言(DM Control),研究團隊使用了較為熟知的 Soft Actor-Critic(SAC)(Haarnoja et al., 2018) ;而對於離散控制基準(Atari),研究團隊使用了 Rainbow DQN(Hessel et al., 2017))。下面,我們簡要回顧一下 SAC,Rainbow DQN 以及對比學習。

Soft Actor Critic

SAC 是一種異策略 RL 演算法,它最佳化了隨機策略,以最大化預期的軌跡回報。像其他 SOTA 端到端的 RL 演算法一樣,SAC 在從狀態觀察中解決任務時非常有效,但卻無法從畫素中學習有效的策略。

Rainbow

最好將 Rainbow DQN(Hessel et al., 2017)總結為在原來應用 Nature DQN 之上的多項改進(Mnih et al., 2015)。具體來說,深度 Q 網路(DQN)(Mnih et al., 2015)將異策略演算法 Q-Learning 與卷積神經網路作為函式逼近器相結合,將原始畫素對映到動作價值函式里。

除此之外,價值分佈強化學習(Bellemare et al., 2017)提出了一種透過 C51 演算法預測可能值函式 bin 上的分佈技術。Rainbow DQN 將上述所有技術組合在單一的異策略演算法中,用以實現 Atari 基準的最新 sample efficiency。此外,Rainbow 還使用了多步回報(Sutton et al.,1998)。

對比學習

CURL 的關鍵部分是使用對比無監督學習來學習高維資料的豐富表示的能力。對比學習可以理解為可區分的字典查詢任務。給定一個查詢 q、鍵 K= {k_0, k_1, . . . } 以及一個明確的 K(關於 q)P(K) = ({k+}, K \ {k+}) 分割槽,對比學習的目標是確保 q 與 k +的匹配程度比 K \ {k +} 中的任何的鍵都更大。在對比學習中,q,K,k +和 K \ {k +} 也分別稱為錨點(anchor),目標(targets),正樣本(positive), 負樣本(negatives)。

CURL 具體實現

CURL 透過將訓練對比目標作為批更新時的輔助損失函式,在最小程度上改變基礎 RL 演算法。在實驗中,研究者將 CURL 與兩個無模型 RL 演算法一同訓練——SAC 用於 DMControl 實驗,Rainbow DQN 用於 Atari 實驗。

總體框架概述

CURL 使用的例項判別方法(instance discrimination)類似於 SimCLR、MoC 和 CPC。大多數深度強化學習框架採用一系列堆疊在一起的影像作為輸入。因此,演算法在多個堆疊的幀中進行例項判別,而不是單一的影像例項。

研究者發現,使用類似於 MoCo 的動量編碼流程(momentum encoding)來處理目標,在 RL 中效能較好。最後,研究者使用一個類似於 CPC 中的雙線性內積來處理 InfoNCE score 方程,研究者發現效果比 MoCo 和 SimCLR 中的單位範數向量積(unit norm vector products)要好。對比表徵和 RL 演算法一同進行訓練,同時從對比目標和 Q 函式中獲得梯度。總體框架如下圖所示。

BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍

圖 2:CURL 總體框架示意圖

判別目標

選擇關於一個錨點的正、負樣本是對比表徵學習的其中一個關鍵組成部分。

不同於在同一張影像上的 image-patches,判別變換後的影像例項最佳化帶有 InfoNCE 損失項的簡化例項判別目標函式,並需要最小化對結構的調整。在 RL 設定下,選擇更簡化判別目標的理由主要有如下兩點:

  • 鑑於 RL 演算法十分脆弱,複雜的判別目標可能導致 RL 目標不穩定。

  • RL 演算法在動態生成的資料集上進行訓練,複雜的判別目標可能會顯著增加訓練所需時間。


因此,CURL 使用例項判別而不是 patch 判別。我們可將類似於 SimCLR 和 MoCo 這樣的對比例項判別設定,看做最大化一張影像與其對應增廣版本之間的共同資訊。

查詢-鍵值對的生成

類似於在影像設定下的例項判別,錨點和正觀測值是來自同一幅影像的兩個不同增廣值,而負觀測值則來源於其他影像。CURL 主要依靠隨機裁切資料增廣方法,從原始渲染影像中隨機裁切一個正方形的 patch。

研究者在批資料上使用隨機資料增廣,但在同一堆幀之間保持一致,以保留觀測值時間結構的資訊。資料增廣流程如圖 3 所示。

BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍

圖 3: 使用隨機裁剪產生錨點與其正樣本過程的直觀展示。

相似度量

區分目標中的另一個決定因素是用於測量查詢鍵對之間的內部乘積。CURL 採用雙線性內積 sim(q,k)= q^TW_k,其中 W 是學習的引數矩陣。研究團隊發現這種相似性度量的效能優於最近在計算機視覺(如 MoCo 和 SimCLR)中最新的對比學習方法中使用的標準化點積。

動量目標編碼

在 CURL 中使用對比學習的目標是訓練從高維畫素中能對映到更多語義隱狀態的編碼器。InfoNCE 是一種無監督的損失,它透過學習編碼器 f_q 和 f_k 將原始錨點(查詢)x_q 和目標(關鍵字)x_k 對映到潛在值 q = f_q(x_q) 和 k = f_k(x_k) 上,在此團隊應用相似點積。通常在錨點和目標對映之間共享相同的編碼器,即 f_q = f_k。

CURL 將幀-堆疊例項的識別與目標的動量編碼結合在一起,同時 RL 是在編碼器特徵之上執行的。

CURL 對比學習虛擬碼(PyTorch 風格)

BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍


實驗

研究者評估(i)sample-efficiency,方法具體為測量表現最佳的基線需要多少個互動步驟才能與 100k 互動步驟的 CURL 效能相匹配,以及(ii)透過測量 CURL 取得的週期回報值與最佳表現基線的比例來對效能層面的 100k 步驟進行衡量。換句話說,當談到資料或 sample-efficiency 時,其實指的是(i),而當談起效能時則指的是(ii)。

DMControl

在 DMControl 實驗中的主要發現:


  1. CURL 是我們在每個 DMControl 環境上進行基準測試的 SOTA ImageBased RL 演算法,用於根據現有的 Image-based 的基準進行取樣效率測試。在 DMControl100k 上,CURL 的效能比 Dreamer(Hafner 等人,2019)高 2.8 倍,這是一種領先的 model-based 的方法,並且資料效率高 9.9 倍。

  2. 從圖 7 所示的大多數 16 種 DMControl 環境中的狀態開始,僅靠畫素操作的 CURL 幾乎可以進行匹配(有時甚至超過)SAC 的取樣效率。它是基於 model-based,model-free,有輔助任務或者是沒有輔助任務。

  3. 在 50 萬步之內,CURL 解決了 16 個 DMControl 實驗中的大多數(收斂到接近 1000 的最佳分數)。它在短短 10 萬步的時間內就具有與 SOTA 相似效能的競爭力,並且大大優於該方案中的其他方法。


BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍

表 1. 在 500k(DMControl500k)和 100k(DMControl100k)環境步長基準下,CURL 和 DMControl 基準上獲得的基線得分。

BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍

圖 4. 相對於 SLAC、PlaNet、Pixel SAC 和 State SAC 基線,平均 10 個 seeds 的 CURL 耦合 SAC 效能。

BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍

圖 6. 要獲得與 CURL 在 100k 訓練步驟中所得分相同的分數,需要先行採用領先的 pixel-based 方法 Dreamer 的步驟數。

BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍

圖 7. 將 CURL 與 state-based 的 SAC 進行比較,在 16 個所選 DMControl 環境中的每個環境上執行 2 個 seeds。

Atari

在 Atari 實驗中的主要發現:

  1. 就大多數 26 項 Atari100k 實驗的資料效率而言,CURL 是 SOTA PixelBased RL 演算法。平均而言,在 Atari100k 上,CURL 的效能比 SimPLe 高 1.6 倍,而 Efficient Rainbow DQN 則高 2.5 倍。

  2. CURL 達到 24%的人類標準化分數(HNS),而 SimPLe 和 Efficient Rainbow DQN 分別達到 13.5%和 14.7%。CURL,SimPLe 和 Efficient Rainbow DQN 的平均 HNS 分別為 37.3%,39%和 23.8%。

  3. CURL 在三款遊戲 JamesBond(98.4%HNS),Freeway(94.2%HNS)和 Road Runner(86.5%HNS)上幾乎可以與人類的效率相提並論,這在所有 pixel-based 的 RL 演算法中均屬首例。


表 2. 透過 CURL 和以 10 萬個時間步長(Atari100k)為標準所獲得的分數。CURL 在 26 個環境中的 14 個環境中實現了 SOTA。

BAIR最新RL演算法超越谷歌Dreamer,效能提升2.8倍

專案介紹

安裝

所有相關項都在 conda_env.yml 檔案中。它們可以手動安裝,也可以使用以下命令安裝:

conda env create -f conda_env.yml

使用說明

要從基於影像的觀察中訓練 CURL agent 完成 cartpole swingup 任務,請從該目錄的根目錄執行 bash script/run.sh。run.sh 檔案包含以下命令,也可以對其進行修改以嘗試不同的環境/超引數

CUDA_VISIBLE_DEVICES=0 python train.py \
    --domain_name cartpole \
    --task_name swingup \
    --encoder_type pixel \
    --action_repeat 8 \
    --save_tb --pre_transform_image_size 100 --image_size 84 \
    --work_dir ./tmp \
    --agent curl_sac --frame_stack 3 \
    --seed -1 --critic_lr 1e-3 --actor_lr 1e-3 --eval_freq 10000 --batch_size 128 --num_train_steps 1000000

在控制檯中,應該看到如下所示的輸出:

| train | E: 221 | S: 28000 | D: 18.1 s | R: 785.2634 | BR: 3.8815 | A_LOSS: -305.7328 | CR_LOSS: 190.9854 | CU_LOSS: 0.0000
| train | E: 225 | S: 28500 | D: 18.6 s | R: 832.4937 | BR: 3.9644 | A_LOSS: -308.7789 | CR_LOSS: 126.0638 | CU_LOSS: 0.0000
| train | E: 229 | S: 29000 | D: 18.8 s | R: 683.6702 | BR: 3.7384 | A_LOSS: -311.3941 | CR_LOSS: 140.2573 | CU_LOSS: 0.0000
| train | E: 233 | S: 29500 | D: 19.6 s | R: 838.0947 | BR: 3.7254 | A_LOSS: -316.9415 | CR_LOSS: 136.5304 | CU_LOSS: 0.0000

cartpole swing up 的最高分數約為 845 分。而且,CURL 如何以小於 50k 的步長解決 visual cartpole。根據使用者的 GPU 不同而定,大約需要一個小時的訓練。同時作為參考,最新的端到端方法 D4PG 需要 50M 的 timesteps 來解決相同的問題。

Log abbreviation mapping:

train - training episode
E - total number of episodes 
S - total number of environment steps
D - duration in seconds to train 1 episode
R - mean episode reward
BR - average reward of sampled batch
A_LOSS - average loss of actor
CR_LOSS - average loss of critic
CU_LOSS - average loss of the CURL encoder

與執行相關的所有資料都儲存在指定的 working_dir 中。若要啟用模型或影片儲存,請使用--save_model 或--save_video。而對於所有可用的標誌,需要檢查 train.py。使用 tensorboard 執行來進行視覺化:

tensorboard --logdir log --port 6006

同時在瀏覽器中轉到 localhost:6006。如果執行異常,可以嘗試使用 ssh 進行埠轉發。

對於使用 GPU 加速渲染,確保在計算機上安裝了 EGL 並設定了 export MUJOCO_GL = egl。


相關文章