論文結果難復現?本文教你完美實現深度強化學習演算法DQN

機器之心發表於2017-11-24
論文的復現一直是很多研究者和開發者關注的重點,近日有研究者詳細論述了他們在復現深度 Q 網路所踩過的坑與訓練技巧。本論文不僅重點標註了實現過程中的終止條件和優化演算法等關鍵點,同時還討論了實現的效能改進方案。機器之心簡要介紹了該論文,更詳細的實現細節請檢視原論文。

過去幾年來,深度強化學習逐漸流行,因為它在有超大狀態空間(state-spaces)的領域上要比先前的方法有更好的表現。DQN 幾乎在所有的遊戲上超越了之前的強化學習方法,並在大部分遊戲上比人類表現更好。隨著更多的研究人員用深度強化學習方法解決強化學習問題,並提出替代性演算法,DQN 論文的結果經常被用作展示進步的基準。因此,實現 DQN 演算法對復現 DQN 論文結果和構建原演算法都很重要。

我們部署了一個 DQN 來玩 Atari 遊戲並重復 Mnih 等人的結果。我們的實現要比原始實現快 4 倍,且已經在網上開源。此外,該實現在設計上,對不同的神經網路架構、ALE 之外領域也更為靈活。在重複這些結果時,我們發現實現這些系統的過程的幾個關鍵。在這篇論文中,我們強調了一些關鍵的技術,這些技術對於獲得優良的效能和重複 Mnih 等人的結果是很基本的,其中包括了終止條件和梯度下降優化演算法,以及演算法的期望結果(也就是網路的效能波動)。

論文:Implementing the Deep Q-Network

論文結果難復現?本文教你完美實現深度強化學習演算法DQN

論文地址:https://arxiv.org/abs/1711.07478

Mnih 等人在 2015 年提出的深度 Q 網路已經成為了一項基準,也是許多深度強化學習研究的基點。然而,復現複雜系統的結果總是非常難,因為最初的文獻經常無法詳細描述每個重要的引數和軟體工程的解決方案。在此論文中,我們復現了 DQN 的論文結果。此外,我們重點標註了實現過程中的關鍵點,從而讓研究人員能更容易地復現結果,包括終止條件、梯度下降演算法等。而這些點是原論文沒有詳細描述的。最後,我們討論了改進計算效能的方法,並給出我們的實現,該實現可廣泛應用,而不是隻能在原論文中的 Arcade 學習環境(ALE)中實現。

3 深度 Q 學習

深度 Q 學習(DQN)是經典 Q 學習演算法的變體,有 3 個主要貢獻:(1)深度卷積神經網路架構用於 Q 函式近似;(2)使用小批量隨機訓練資料而不是在上一次經驗上進行一步更新;(3)使用舊的網路引數來評估下一個狀態的 Q 值。DQN 的虛擬碼(複製自 Mnih et al. [2015])見演算法 1。深度卷積架構提供一個通用機制從影像幀的短歷史(尤其是最後 4 幀)中評估 Q 函式的值。後面兩個貢獻主要關於如何使迭代的 Q 函式估計保持穩定。

論文結果難復現?本文教你完美實現深度強化學習演算法DQN

監督式深度學習研究中,在小批量資料上執行梯度下降通常是一種高效訓練網路的方式。在 DQN 中,它扮演了另外一個角色。具體來說,DQN 儲存大量最近經驗的歷史,每個經驗有五個元組(s, a, s', r, T):智慧體在狀態 s 執行動作 a,然後到達狀態 s',收到獎勵 r;T 是一個布林值,指示 s'是否為最終狀態。在環境中的每一步之後,智慧體新增經驗至記憶體。在少量步之後(DQN 論文使用了 4 步),智慧體從記憶體中進行小批量隨機取樣,然後在上面執行 Q 函式更新。在 Q 函式更新中重用先前的經驗叫作經驗回放(experience replay)[Lin, 1992]。但是,儘管強化學習中的經驗回放通常用於加快獎勵備份(backup of rewards),DQN 從記憶體中進行小批量完全隨機取樣有助於去除樣本和環境的相關性,否則容易引起函式近似估計中出現偏差。

最終的主要貢獻是使用舊的網路引數來評估一個經驗中下一個狀態的 Q 值,且只在離散的多步間隔(many-step interval)上更新舊的網路引數。該方法對 DQN 很有用,因為它為待擬合的網路函式提供了一個穩定的訓練目標,並給予充分的訓練時間(根據訓練樣本數量決定)。因此,估計誤差得到了更好地控制。

儘管這些貢獻和整體演算法在概念層面上是很直接的,但要想達到 Mnih et al. [2015] 報告中相同的效能水平需要考慮大量重要細節,設計者必須牢記學習過程的重要特性。下文將具體描述細節。

3.1 實現細節

由於原始的科研文獻經常無法提供重要引數設定和軟體工程解決方案的細節,因此,很多大型系統(比如 DQN)都難以實現。因此,DQN 論文並沒有明確地提及或完整地說明一些重要的演算法基礎細節。本文,我們將強調其中一些額外的關鍵實現細節(根據原論文的 DQN 程式碼總結)。

首先,每一個 episode 從隨機數量(0 到 30 之間)的「No-op」低階別 Atari 動作開始(相對於將智慧體的動作(action)重複 4 個幀),以抵消智慧體所看見的幀,這是因為智慧體每次只能看到 4 個 Atari 幀。類似地,用作 CNN 輸入的 m 個幀歷史是智慧體最後看見的 m 個幀,而不是最後的 m 個 Atari 幀。此外,在使用梯度下降迭代之前,我們會執行 50000 步的隨機策略作為補充經驗以避免對早期經驗的過擬合。

另一個值得注意的引數是網路更新頻率(network update frequency)。原始的 DQN 實現僅在演算法的每 4 個環境步驟後執行一個梯度下降步驟,這和演算法 1 截然不同(每一個環境步驟執行一個梯度下降步驟)。這不僅僅大大加快了訓練速度(由於網路學習步驟的計算量比前向傳播大得多),還使得經驗記憶體更加相似於當前策略的狀態分佈(由於訓練步驟之間需要新增 4 個新的幀到記憶體中,這和新增 1 個幀是截然不同的),可能有防止過擬合的作用。

3.2 DQN 的效能波動(fluctuating performance)

圖 1 展示了最佳網路和最差網路(在 Breakout 遊戲的開始階段使用相同的輸入啟動訓練)的 Q 值近似。第一幀展示了這樣的場景:智慧體可以採取任意的動作,都不會使球在經過未來幾個動作之後就掉落。但是在彈回球之前做出的動作也可以幫助智慧體瞄準球的位置。這個例子中兩個網路的 Q 值是很相近的,但是各自選擇的動作是不同的。在第二幀的場景中,假如智慧體沒有采取向左移動的動作,球就會掉落,遊戲終止。在這個例子中,兩個網路的 Q 值差別是很大的。因此,當執行這個演算法的時候,可能會出現這種效能波動。

論文結果難復現?本文教你完美實現深度強化學習演算法DQN

圖 1:經過 3000 萬步的訓練後進行測試,最優和最差網路的 Q 值對比。陰影線區域代表 Q 值最高的動作。最上面一幀對應動作對近期獎勵無顯著影響的情況,底部幀代表必須執行左側動作以免損失生命的情況。「Release」動作指在每局開始的時候釋放球,或當球已經開始運動時什麼也不做(和「無操作」(No-op)一樣)。

5 結果

我們的結果與 DQN 論文關於 Pong、Breakout 和 Seaquest 的結果對比見表 1。我們的實現中每個訓練過程大約用時 3 天,而我們配置的原始實現用時大約 10.5 天。

論文結果難復現?本文教你完美實現深度強化學習演算法DQN

表 1:我們的 DQN 實現和原 DQN 論文獲得的平均遊戲分數的對比。

6 核心訓練技巧

我們在實現 DQN 時,發現了只在 DQN 論文中簡要提及的兩種方法,但是它們對演算法的整體表現至關重要。下面我們將展示這兩種方法,並解釋為什麼它們對網路訓練的影響如此之大。

6.1 掉命終止

絕大多數 Atari 遊戲中,玩家都有幾條「命」,對應遊戲結束之前玩家可以失敗的次數。為了提升表現,Mnih et al. [2015] 選擇在訓練中把生命數的損失(在涉及生命數的遊戲中)作為 MDP 的最終狀態。這一終止條件在 DQN 論文中沒有提及太多,但卻對提升效能至關重要。

圖 2 展示了在 Breakout 和 Seaquest 中,把和不把生命數損失作為最終狀態的區別。在 Breakout 中,使用生命數的結束作為最終狀態的學習器的平均分值增長要遠快於另一個學習器。但是,訓練大約進行一半時,另一個學習器獲得相似表現,卻帶有更高的方差。Seaquest 是一個更為複雜的遊戲,其中使用生命數作為最終狀態的學習器在整個訓練中表現要遠好於另一個學習器。這些圖表明這一額外的先驗資訊非常有利於早期訓練和穩定性,並在更復雜的遊戲中顯著提升了整體表現。

如上所述,MDP 中的最終狀態意味著智慧體無法再獲得更多獎勵。幾乎所有的 Atari 遊戲給出正面獎勵,因此這一附加資訊很關鍵地告知智慧體無論如何都要避免失去生命數,這看起來確實很理性:很多玩家一開始就知道在 Atari 中損失生命數很糟糕,並且很難想象出其中最優策略是失去生命數的場景。

但是,執行該約束存在多個理論問題。首先,由於初始狀態分佈依賴於當前策略,該過程將不再是馬爾科夫性質的。一個相關的例子是在 Breakout 遊戲中:如果智慧體表現很好,在失去一條生命之前破壞了很多磚,則新生命的初始狀態擁有的磚,會比智慧體在上個生命中表現不好、破壞不多磚時更少。另一個問題是該訊號為 DQN 提供了很強的額外資訊,從而使擴充套件至沒有強訊號的領域變得困難(如現實機器人或開放性更強的電子遊戲)。

論文結果難復現?本文教你完美實現深度強化學習演算法DQN

圖 2:Breakout 和 Seaquest 在每個測試集上使用命數和遊戲結束作為最終狀態時分別得到的平均訓練測試分數(epoch = 250,000 steps)。

ALE 為每個遊戲儲存了剩餘生命數,但它沒有向所有介面提供這個資訊。為了解決這個侷限,我們修改了 ALE 的 FIFO 介面以在螢幕上提供剩餘生命數、獎勵和最終狀態布林值的資訊。我們的 fork 在 FIFO 介面上提供了該資料,大家可線上免費訪問。

6.2 梯度下降優化演算法

在使用 Mnih et al. [2015] 所提供的超引數時,我們會遇到一個潛在問題,即原論文並不是直接使用許多深度學習庫(如 Caffe)所定義的 RMSProp 優化演算法。RMSProp 梯度下降優化演算法最初是由 Gerffrey Hinton 所提出來的,Hinton 的 RMSProp 針對每個引數保持一個滑動平均(running average)梯度。這種滑動平均梯度的更新規則可以寫為:

論文結果難復現?本文教你完美實現深度強化學習演算法DQN

其中,w 對應單個網路的引數,γ 為梯度衰減引數,E 為經驗損失。引數的更新過程可以寫為:

論文結果難復現?本文教你完美實現深度強化學習演算法DQN

其中α為學習率,ε為非常小的常量以避免分母為零。

即使 Mnih et al. [2015] 引用了 Hinton 的 RMSProp,但他們使用的最優化演算法仍然略有不同。這個不同點可以在他們的 GitHub 中找到(以下地址),即在 NeuralQLearner.lua 檔案的第 266 行到 273 行程式碼中。該變體將動量因子加入到了 RMSProp 演算法中,因此梯度的更新規則可以寫為:

論文結果難復現?本文教你完美實現深度強化學習演算法DQN

Mnih 等人實現地址:www.github.com/kuz/DeepMind-Atari-Deep-Q-Learner

其中η為動量衰減因素,引數的更新規則可以寫為:

論文結果難復現?本文教你完美實現深度強化學習演算法DQN

為了解決優化演算法中的這種大幅變化,我們必須將學習率修改為遠低於 Mnih et al. [2015] 在實現中設定的學習率,即將他們的 0.00025 修改為 0.00005。我們並沒有選擇實現這種 RMSProp 變體,因為用 Java-Caffe 捆綁包實現是很重要的,且 Hinton 的一般 RMSPorp 演算法產生了類似的效果。

7 效能加速

我們的實現要比原論文使用 Lua 和 Torch 的實現快 4 倍,且測試這些實現的配置是兩張 NVIDIA GTX 980 TI 顯示卡和一個 Intel i7 處理器。我們效能的提升很大部分可以歸因於 cuDNN 庫的幫助,我們在訓練過程中以每秒約 985 Atari 幀(fps)的速度進行,測試中以每秒約 1584 幀(fps)的速度進行。

我們使用了 cuDNN 進行實驗,而 Lua 並沒有在 Torch 中使用該加速庫。為了完成對比,我們在沒有使用 cuDNN 的 Caffe 上訓練和測試時,速度分別為 268fps 和 485fps。這要比原論文 Lua 實現慢一些。

8. 結論

為了讓研究人員更好地實現自己的 DQN,我們在此論文中展現了實現 Mnih 等人提出的 DQN 時的關鍵點,這些關鍵點對此演算法的整體表現極為重要,但在原論文中卻沒有提到,以幫助研究者更容易地實現該演算法的個人版本。我們也重點標註了在災難性遺忘(catastrophic forgetting)這樣的大型狀態空間中用 CNN 逼近 Q 函式時的難點。之後,我們把自己的實現開源到了網上,也鼓勵研究人員使用它實現全新的演算法,並與 Mnih 等人的結果做比較。 

本論文的 GitHub 實現地址:https://github.com/h2r/burlap_caffe

相關文章