強化學習(十一) Prioritized Replay DQN

劉建平Pinard發表於2018-10-16

    在強化學習(十)Double DQN (DDQN)中,我們講到了DDQN使用兩個Q網路,用當前Q網路計算最大Q值對應的動作,用目標Q網路計算這個最大動作對應的目標Q值,進而消除貪婪法帶來的偏差。今天我們在DDQN的基礎上,對經驗回放部分的邏輯做優化。對應的演算法是Prioritized Replay DQN。

    本章內容主要參考了ICML 2016的deep RL tutorial和Prioritized Replay DQN的論文<Prioritized Experience Replay>(ICLR 2016)。

1. Prioritized Replay DQN之前演算法的問題

    在Prioritized Replay DQN之前,我們已經討論了很多種DQN,比如Nature DQN, DDQN等,他們都是通過經驗回放來取樣,進而做目標Q值的計算的。在取樣的時候,我們是一視同仁,在經驗回放池裡面的所有的樣本都有相同的被取樣到的概率。

    但是注意到在經驗回放池裡面的不同的樣本由於TD誤差的不同,對我們反向傳播的作用是不一樣的。TD誤差越大,那麼對我們反向傳播的作用越大。而TD誤差小的樣本,由於TD誤差小,對反向梯度的計算影響不大。在Q網路中,TD誤差就是目標Q網路計算的目標Q值和當前Q網路計算的Q值之間的差距。

    這樣如果TD誤差的絕對值$|\delta(t)|$較大的樣本更容易被取樣,則我們的演算法會比較容易收斂。下面我們看看Prioritized Replay DQN的演算法思路。

2.  Prioritized Replay DQN演算法的建模

    Prioritized Replay DQN根據每個樣本的TD誤差絕對值$|\delta(t)|$,給定該樣本的優先順序正比於$|\delta(t)|$,將這個優先順序的值存入經驗回放池。回憶下之前的DQN演算法,我們僅僅只儲存和環境互動得到的樣本狀態,動作,獎勵等資料,沒有優先順序這個說法。

    由於引入了經驗回放的優先順序,那麼Prioritized Replay DQN的經驗回放池和之前的其他DQN演算法的經驗回放池就不一樣了。因為這個優先順序大小會影響它被取樣的概率。在實際使用中,我們通常使用SumTree這樣的二叉樹結構來做我們的帶優先順序的經驗回放池樣本的儲存。

    具體的SumTree樹結構如下圖:

    所有的經驗回放樣本只儲存在最下面的葉子節點上面,一個節點一個樣本。內部節點不儲存樣本資料。而葉子節點除了儲存資料以外,還要儲存該樣本的優先順序,就是圖中的顯示的數字。對於內部節點每個節點只儲存自己的兒子節點的優先順序值之和,如圖中內部節點上顯示的數字。

    這樣儲存有什麼好處呢?主要是方便取樣。以上面的樹結構為例,根節點是42,如果要取樣一個樣本,那麼我們可以在[0,42]之間做均勻取樣,取樣到哪個區間,就是哪個樣本。比如我們取樣到了26, 在(25-29)這個區間,那麼就是第四個葉子節點被取樣到。而注意到第三個葉子節點優先順序最高,是12,它的區間13-25也是最長的,會比其他節點更容易被取樣到。

    如果要取樣兩個樣本,我們可以在[0,21],[21,42]兩個區間做均勻取樣,方法和上面取樣一個樣本類似。

    類似的取樣演算法思想我們在word2vec原理(三) 基於Negative Sampling的模型第四節中也有講到。

    除了經驗回放池,現在我們的Q網路的演算法損失函式也有優化,之前我們的損失函式是:$$\frac{1}{m}\sum\limits_{j=1}^m(y_j-Q(\phi(S_j),A_j,w))^2$$

    現在我們新的考慮了樣本優先順序的損失函式是$$\frac{1}{m}\sum\limits_{j=1}^mw_j(y_j-Q(\phi(S_j),A_j,w))^2$$

    其中$w_j$是第j個樣本的優先順序權重,由TD誤差$|\delta(t)|$歸一化得到。

    第三個要注意的點就是當我們對Q網路引數進行了梯度更新後,需要重新計算TD誤差,並將TD誤差更新到SunTree上面。

    除了以上三個部分,Prioritized Replay DQN和DDQN的演算法流程相同。

3. Prioritized Replay DQN演算法流程

    下面我們總結下Prioritized Replay DQN的演算法流程,基於上一節的DDQN,因此這個演算法我們應該叫做Prioritized Replay DDQN。主流程參考論文<Prioritized Experience Replay>(ICLR 2016)。

    演算法輸入:迭代輪數$T$,狀態特徵維度$n$, 動作集$A$, 步長$\alpha$,取樣權重係數$\beta$,衰減因子$\gamma$, 探索率$\epsilon$, 當前Q網路$Q$,目標Q網路$Q'$, 批量梯度下降的樣本數$m$,目標Q網路引數更新頻率$C$, SumTree的葉子節點數$S$。

    輸出:Q網路引數。

    1. 隨機初始化所有的狀態和動作對應的價值$Q$.  隨機初始化當前Q網路的所有引數$w$,初始化目標Q網路$Q'$的引數$w' = w$。初始化經驗回放SumTree的預設資料結構,所有SumTree的S個葉子節點的優先順序$p_j$為1。

    2. for i from 1 to T,進行迭代。

      a) 初始化S為當前狀態序列的第一個狀態, 拿到其特徵向量$\phi(S)$

      b) 在Q網路中使用$\phi(S)$作為輸入,得到Q網路的所有動作對應的Q值輸出。用$\epsilon-$貪婪法在當前Q值輸出中選擇對應的動作$A$

      c) 在狀態$S$執行當前動作$A$,得到新狀態$S'$對應的特徵向量$\phi(S')$和獎勵$R$,是否終止狀態is_end

      d) 將$\{\phi(S),A,R,\phi(S'),is\_end\}$這個五元組存入SumTree

      e) $S=S'$

      f)  從SumTree中取樣$m$個樣本$\{\phi(S_j),A_j,R_j,\phi(S'_j),is\_end_j\}, j=1,2.,,,m$,每個樣本被取樣的概率基於$P(j) = \frac{p_j}{\sum\limits_i(p_i)}$,損失函式權重$w_j = (N*P(j))^{-\beta}/\max_i(w_i)$,計算當前目標Q值$y_j$:$$y_j= \begin{cases} R_j& {is\_end_j\; is \;true}\\ R_j + \gamma Q'(\phi(S'_j),\arg\max_{a'}Q(\phi(S'_j),a,w),w')& {is\_end_j\; is \;false} \end{cases}$$

      g)  使用均方差損失函式$\frac{1}{m}\sum\limits_{j=1}^mw_j(y_j-Q(\phi(S_j),A_j,w))^2$,通過神經網路的梯度反向傳播來更新Q網路的所有引數$w$

      h) 重新計算所有樣本的TD誤差$\delta_j = y_j- Q(\phi(S_j),A_j,w)$,更新SumTree中所有節點的優先順序$p_j = |\delta_j|$

      i) 如果T%C=1,則更新目標Q網路引數$w'=w$

      j) 如果$S'$是終止狀態,當前輪迭代完畢,否則轉到步驟b)

      注意,上述第二步的f步和g步的Q值計算也都需要通過Q網路計算得到。另外,實際應用中,為了演算法較好的收斂,探索率$\epsilon$需要隨著迭代的進行而變小。

4. Prioritized Replay DDQN演算法流程

    下面我們給出Prioritized Replay DDQN演算法的例項程式碼。仍然使用了OpenAI Gym中的CartPole-v0遊戲來作為我們演算法應用。CartPole-v0遊戲的介紹參見這裡。它比較簡單,基本要求就是控制下面的cart移動使連線在上面的pole保持垂直不倒。這個任務只有兩個離散動作,要麼向左用力,要麼向右用力。而state狀態就是這個cart的位置和速度, pole的角度和角速度,4維的特徵。堅持到200分的獎勵則為過關。

    完整的程式碼參見我的github: https://github.com/ljpzzz/machinelearning/blob/master/reinforcement-learning/ddqn_prioritised_replay.py, 程式碼中的SumTree的結構和經驗回放池的結構參考了morvanzhou的github程式碼

    這裡重點講下和第三節中演算法描述不同的地方,主要是$w_j$的計算。注意到:$$w_j = \frac{ (N*P(j))^{-\beta}}{\max_i(w_i)} =  \frac{ (N*P(j))^{-\beta}}{\max_i((N*P(i))^{-\beta})} =  \frac{ (P(j))^{-\beta}}{\max_i((P(i))^{-\beta})} =( \frac{p_j}{\min_iP(i)})^{-\beta}$$

    因此程式碼裡面$w_j$,即ISWeights的計算程式碼是這樣的:

    def sample(self, n):
        b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), np.empty((n, self.tree.data[0].size)), np.empty((n, 1))
        pri_seg = self.tree.total_p / n       # priority segment
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # max = 1

        min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p     # for later calculate ISweight
        if min_prob == 0:
            min_prob = 0.00001
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            v = np.random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i, 0] = np.power(prob/min_prob, -self.beta)
            b_idx[i], b_memory[i, :] = idx, data
        return b_idx, b_memory, ISWeights

    上述程式碼的取樣在第二節已經講到。根據樹的優先順序的和total_p和取樣數n,將要取樣的區間劃分為n段,每段來進行均勻取樣,根據取樣到的值落到的區間,決定被取樣到的葉子節點。當我們拿到第i段的均勻取樣值v以後,就可以去SumTree中找對應的葉子節點拿樣本資料,樣本葉子節點序號以及樣本優先順序了。程式碼如下:

    def get_leaf(self, v):
        """
        Tree structure and array storage:
        Tree index:
             0         -> storing priority sum
            / \
          1     2
         / \   / \
        3   4 5   6    -> storing priority for transitions
        Array type for storing:
        [0,1,2,3,4,5,6]
        """
        parent_idx = 0
        while True:     # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1         # this leaf's left and right kids
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):        # reach bottom, end search
                leaf_idx = parent_idx
                break
            else:       # downward search, always search for a higher priority node
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    除了取樣部分,要注意的就是當梯度更新完畢後,我們要去更新SumTree的權重,程式碼如下,注意葉子節點的權重更新後,要向上回溯,更新所有祖先節點的權重。

    self.memory.batch_update(tree_idx, abs_errors)  # update priority
    def batch_update(self, tree_idx, abs_errors):
        abs_errors += self.epsilon  # convert to abs and avoid 0
        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # then propagate the change through tree
        while tree_idx != 0:    # this method is faster than the recursive loop in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    除了上面這部分的區別,和DDQN比,TensorFlow的網路結構流程中多了一個TD誤差的計算節點,以及損失函式多了一個ISWeights係數。此外,區別不大。

5. Prioritized Replay DQN小結

    Prioritized Replay DQN和DDQN相比,收斂速度有了很大的提高,避免了一些沒有價值的迭代,因此是一個不錯的優化點。同時它也可以直接整合DDQN演算法,所以是一個比較常用的DQN演算法。

    下一篇我們討論DQN家族的另一個優化演算法Duel DQN,它將價值Q分解為兩部分,第一部分是僅僅受狀態但不受動作影響的部分,第二部分才是同時受狀態和動作影響的部分,演算法的效果也很好。

 

(歡迎轉載,轉載請註明出處。歡迎溝通交流: liujianping-ok@163.com)

相關文章