強化學習 6 ——價值函式逼近

jsfantasy發表於2020-09-06

上篇文章強化學習——時序差分 (TD) 控制演算法 Sarsa 和 Q-Learning我們主要介紹了 Sarsa 和 Q-Learning 兩種時序差分控制演算法,在這兩種演算法內部都要維護一張 Q 表格,對於小型的強化學習問題是非常靈活高效的。但是在狀態和可選動作非常多的問題中,這張Q表格就變得異常巨大,甚至超出記憶體,而且查詢效率極其低下,從而限制了時序差分的應用場景。近些年來,隨著神經網路的興起,基於深度學習的強化學習稱為了主流,也就是深度強化學習(DRL)。

一、函式逼近介紹

我們知道限制 Sarsa 和 Q-Learning 的應用場景原因是需要維護一張巨大的 Q 表格,那麼我們能不能用其他的方式來代替 Q表格呢?很自然的,就想到了函式。

\[\hat{v}(s, w) \approx v_\pi(s) \\ \hat{q}(s,a, w) \approx q_\pi(s, a) \\ \hat{\pi}{a,s,w} \approx \pi(a|s) \]

也就是說我們可以用一個函式來代替 Q 表格,不斷更新 \(q(s,a)\) 的過程就可以轉化為用引數來擬合逼近真實 q 值的過程。這樣學習的過程不是更新 Q 表格,而是更新 引數 w 的過程。

UHzRUg.png

下面是幾種不同的擬合方式:

第一種函式接受當前的 狀態 S 作為輸入,輸出擬合後的價值函式

第二種函式同時接受 狀態 S 和 動作 a 作為輸入,輸出擬合後的動作價值函式

第三種函式接受狀態 S,輸出每個動作對應的動作價值函式 q

常見逼近函式有線性特徵組合方式、神經網路、決策樹、最近鄰等,在這裡我們只討論可微分的擬合函式:線性特徵組合和神經網路兩種方式。

1、知道真實 V 的函式逼近

對於給定的一個狀態 S 我們假定我們知道真實的 \(v_\pi(s)\) ,然後我們經過擬合得到 \(\hat{v}(s, w)\) ,於是我們就可以使用均方差來計算損失

\[J(w) = E_\pi[(v_\pi(s) - \hat{v}(s, w))^2] \]

利用梯度下降去找到區域性最小值:

\[\Delta w = -\frac{1}{2}\alpha \nabla_wJ(w) \\ w_{t+1} = w_t + \Delta w \]

我們可以提取一些特徵向量來表示當前的 狀態 S,比如對於 gym 的 CartPole 環境,我們可提取的特徵有推車的位置、推車的速度、木杆的角度、木杆的角速度等

UHz2VS.png $$ x(s) = (x_1(s), x_2(s), \cdots,x_n(s))^T $$
此時價值函式 就可以用線性特徵組合表示:

\[\hat{v}(s,w) = x(s)^Tw=\sum_{j=1}^nx_j(s)\cdot w_j \]

此時的損失函式為:

\[J(w) = E_\pi[(v_\pi(s) - x(s)^T w)^2] \]

因此更新規則為:

\[\Delta w = \alpha(v_\pi(s)-\hat{v}(s,w))\cdot x(s) \\ Update = StepSize\;*\;PredictionError\;*\;FeatureValue \]

二、預測過程中的價值函式逼近

因為我們函式逼近的就是 真實的狀態價值,所以在實際的強化學習問題中是沒有 \(v_\pi(s)\) 的,只有獎勵。所以在函式逼近過程的監督資料為:

\[<S_1, G_1>, <S_2, G_2>, \cdots ,<S_t, G_T> \]

所以對於蒙特卡洛我們有:

\[\Delta w = \alpha({\color{red}G_t} - \hat{v}(s_t, w))\nabla_w\hat{v}(s_t, w) \\ = \alpha({\color{red}G_t} - \hat{v}(s_t, w)) \cdot x(s_t) \]

其中獎勵 \(G_t\) 是無偏(unbiased)的:\(E[G_t] = v_\pi(s_t)\) 。值得一提的是,蒙特卡洛預測過程的函式逼近線上性或者是非線性都能收斂。

對於TD演算法,我們使用 \(\hat{v}(s_t, w)\) 來代替 TD Target。所以我們在價值函式逼近(VFA)使用的訓練資料如下所示:

\[<S_1, R_2+\gamma \hat{v}(s_2, w)>,<S_2, R_3+\gamma \hat{v}(s_3, w)>,\cdots,<S_{T-1}, R_T> \]

於是對於 TD(0) 在預測過程的函式逼近有:

\[\Delta w = \alpha({\color{red}R_{t+1} + \gamma \hat{v}(s_{t+1}, w)}-\hat{v}(s_t, w))\nabla_w\hat{v}(s_t, w) \\ = \alpha({\color{red}R_{t+1} + \gamma \hat{v}(s_{t+1}, w)}-\hat{v}(s_t, w))\cdot x(s) \]

因為TD中的 Target 中包含了預測的 \(\hat{v}(s,t)\) ,所以它對於真實的 \(v_\pi(s_t)\) 是有偏(biased)的,因為我們的監督資料是我們估計出來的,而不是真實的資料。也就是 \(E[R_{t+1} + \gamma \hat{v}(s_{t+1}, w)] \neq v_\pi(s_t)\) 。我們把這個過程叫做 semi-gradient,不是完全的梯度下降,而是忽略了權重向量 w 對 Target 的影響。

三、控制過程中的價值函式逼近

類比於MC 和 TD 在使用 Q 表格時的更新公式,對於策略控制過程我們可以得到如下公式。和上面預測過程一樣,我們沒有真實的 \(q_\pi(s,a)\) ,所以我們對其進行了替代:

  • 對於 MC,Target 是 \(G_t\)

\[\Delta w = \alpha({\color{red}G_t} - \hat{q}(s_t, a_t, w))\nabla_w\hat{v}(s_t, a_t, w) \]
  • 對於 Sarsa,TD Target 是 \(R_{t+1} + \gamma \hat{q}(s_{t+1}, a_{t+1}, w)\) :

\[\Delta w = \alpha ({\color{red}R_{t+1} + \gamma \hat{q}(s_{t+1}, a_{t+1}, w)} - \hat{q}{(s_t, s_t, w)})\cdot \nabla_w\hat{q}{(s_t, a_t, w)} \]
  • 對於 Q-Learning,TD Target 是 \(R_{t+1} + \gamma\;max_a\; \hat{q}(s_{t+1}, a_t, w)\) :

\[\Delta w = \alpha ({\color{red}R_{t+1} + \gamma\;max_a\; \hat{q}(s_{t+1}, a_t, w)} - \hat{q}{(s_t, s_t, w)})\cdot \nabla_w\hat{q}{(s_t, a_t, w)} \]

四、關於收斂的問題

UbwGbd.png

在上圖中,對於使用 Q 表格的問題,不管是MC還是 Sarsa 和 Q-Learning 都能找到最優狀態價值。如果是一個大規模的環境,我們採用線性特徵擬合,其中MC 和 Sarsa 是可以找到一個近似最優解的。當使用非線性擬合(如神經網路),這三種演算法都很難保證能找到一個最優解。

其實對於off-policy 的TD Learning強化學習過程收斂是很困難的,主要有以下原因:

  • 使用函式估計:對於 Sarsa 和 Q-Learning 中價值函式的的近似,其監督資料 Target 是不等於真實值的,因為TD Target 中包含了需要優化的 引數 w,也叫作 半梯度TD,其中會存在誤差。
  • Bootstrapping:在更新式子中,上面紅色字型過程中有 貝爾曼近似過程,也就是使用之前的估計來估計當前的函式,這個過程中也引入了不確定因素。(在這個過程中MC回比TD好一點,因為MC中代替 Target 的 \(G_t\) 是無偏的)。
  • Off-policy 訓練:對於 off-policy 策略控制過程中,我們使用 behavior policy 來採集資料,在優化的時候使用另外的 target policy 策略來優化,兩種不同的策略會導致價值函式的估計變的很不準確。

上面三個因素就導致了強化學習訓練的死亡三角,也是強化學習相對於監督學習訓練更加困難的原因。

下一篇就來介紹本系列的第一個深度強化學習演算法 Deep Q-Learning(DQN)

參考資料:

相關文章