強化學習-學習筆記7 | Sarsa演算法原理與推導

climerecho發表於2022-07-07

Sarsa演算法 是 TD演算法的一種,之前沒有嚴謹推導過 TD 演算法,這一篇就來從數學的角度推導一下 Sarsa 演算法。注意,這部分屬於 TD演算法的延申。

7. Sarsa演算法

7.1 推導 TD target

推導:Derive。

這一部分就是Sarsa 最重要的核心。

折扣回報:$U_t=R_t+\gamma R_{t+1}+\gamma^2 R_{t+2}+\gamma^3 R_{t+3}+\cdots \ \quad={R_t} + \gamma \cdot U_{t+1} $

即 將\(R_{t+1}\)之後 都提出一個 \(\gamma\) 項,後面括號中的式子意義正為 \(U_{t+1}\)

通常認為獎勵 \(R_t\)依賴於 t 時刻的狀態 \(S_t\) 與 動作 \(A_t\) 以及 t+1 時刻的狀態 \(S_{t+1}\)

當時對於為什麼依賴於 \(S_{t+1}\) 有疑問,我回去翻看了 學習筆記1:https://www.cnblogs.com/Roboduster/p/16442003.html ,發現並強調了以下這一點:

“值得注意的是,這個 r1 是什麼時候給的?是在狀態 state s2 的時候給的。”

狀態價值函式 \(Q_\pi({s_t},{a_t}) = \mathbb{E}[U_t|{s_t},{a_t}]\) 是回報 \(U_t\) 的期望;

  • 用折扣回報的變換式,把\(U_t\)替換掉:\(Q_\pi({s_t},{a_t}) = \mathbb{E}[{R_t} + \gamma \cdot U_{t+1} |{s_t}{a_t}]\)
  • 有兩項期望,分解開:\(= \mathbb{E}[{R_t} |{s_t},{a_t}] + \gamma \cdot\mathbb{E}[ U_{t+1} |{s_t},{a_t}]\)

下面研究上式的第二項:\(\mathbb{E}[ U_{t+1} |{s_t},{a_t}]\)

其等於 \(\mathbb{E}[ Q_\pi({s_{t+1}},{a_{t+1}}) |{s_t},{a_t}]\)

Q 是 U 的期望:所以 \(E(E[])=E()\),期望的期望還是原來的期望;這裡是逆用這個性質。這麼做是為了讓等式兩邊都有 \(Q_\pi\) 函式,如下:

於是便得到: \(Q_\pi({s_t},{a_t}) =\mathbb{E}[{R_t} |{s_t},{a_t}] + \gamma\cdot\mathbb{E}[ Q_\pi({s_{t+1}},{a_{t+1}}) {s_t},{a_t}] \\ Q_\pi({s_t},{a_t})=\mathbb{E}[{R_t} + \gamma \cdot Q_\pi({S_{t+1}},{A_{t+1}})]\)

右側有一個期望,但直接求期望很困難,所以通常是對期望求蒙特卡洛近似。

  1. \(R_t\) 近似為觀測到獎勵\(r_t\)
  2. \(Q_\pi({S_{t+1}},{A_{t+1}})\)用觀測到的 \(Q_\pi({s_{t+1}},{a_{t+1}})\) 來近似
  3. 得到蒙特卡洛近似值\(\approx {r_t} + \gamma \cdot Q_\pi({s_{t+1}},{a_{t+1}})\)
  4. 將這個值表示為 TD target \(y_t\)

TD learning 目標:讓 $Q_\pi({s_t},{a_t}) $ 來接近部分真實的獎勵 \(y_t\)

\(Q_\pi\) 完全是估計,而 \(y_t\) 包含了一部分真實獎勵,所以 \(y_t\) 更可靠。

7.2 Sarsa演算法過程

這是一種TD 演算法。

a. 表格形式

如果我們想要學習動作價值 $Q_\pi({s_t},{a_t}) $,假設狀態和動作都是有限的,可以畫一個表來表示:

  1. 表每個元素代表一個動作價值;
  2. 用 Sarsa 演算法更新表格,每次更新一個元素;
  • 在表格形式中,每次觀測到一個四元組\(({s_t},{a_t},{r_t},{s_{t+1}})\),稱為一個 transition

  • 根據策略函式 \(\pi\) 隨機取樣計算下一個動作,記作\({a_{t+1}}\sim\pi(\cdot|{s_{t+1}})\)

  • 計算TD target: \(y_t = {r_t} + \gamma \cdot Q_\pi({s_{t+1}},{a_{t+1}})\)

    前一部分是觀測到的獎勵,後面一部分是對未來動作的打分,\(Q_\pi({s_{t+1}},{a_{t+1}})\) 可以通過查表得知。

    表最開始是通過一定方式初始化的(比如隨機),然後通過不斷計算來更新表格。

    通過查表,還知道\(Q_\pi({s_{t}},{a_{t}})\)的值,可以計算:

  • TD error:\(\delta_t = Q_\pi({s_{t}},{a_{t}}) -y_t\)

  • 最後用 \(\delta_t\) 來更新:\(Q_\pi({s_{t}},{a_{t}}) \leftarrow Q_\pi({s_{t}},{a_{t}}) - \alpha \cdot \delta_t\),並寫入表格相應的位置

    $\alpha $是學習率。通過TD error 更新,可以讓 Q 更好的接近 \(y_t\)

每一步中,Sarsa 演算法用 \((s_t,a_T,r_t,s_{t+1},a_{t+1})\) 來更新 \(Q_\pi\),sarsa,這就是演算法名字的由來。

b. 神經網路形式

值得留意的是表格形式的假設:假設狀態和動作都是有限的,而當狀態和動作很多,表格就會很大,很難學習。

  • 用神經網路-價值網路 \(q({s},{a};w)\) 來近似\(Q_\pi({s},{a})\),Sarsa演算法可以訓練這個價值網路。

    1. actor-critic 那篇用過 Sarsa 演算法,想不起來往下看:
    2. q 和 Q 都與 策略函式 \(\pi\) 有關。
    3. 網路引數 \(\omega\) 初始時隨機初始化,後續不斷更新。

輸入狀態是 s ,輸出就是所有動作的價值

  • actor-critic 方法中,q 作為 critic 用來評估 actor;用 sarsa 這一 TD 學習演算法更新的價值網路。
  • TD target: \(y_t = {r_t} + \gamma \cdot q({s_{t+1}},{a_{t+1}};w)\)
  • TD error:\(\delta_t = q({s_{t}},{a_{t}};w) - y_t\)
  • Loss: \(\delta_t ^2/2\),我們的目的是通過更新網路引數 w 來降低 Loss;
  • 梯度:\(\frac{\partial\delta_t ^2/2}{\partial w} = \delta_t \cdot \frac{\partial q({s_{t}},{a_{t}};w)}{\partial w}\)
  • 梯度下降更新 w:$$w \leftarrow w - \alpha \cdot \delta_t \cdot \frac{\partial q({s_{t}},{a_{t}};w)}{\partial w}$$

7.3 一些解惑 / 有什麼不同

這一篇跟第二篇價值學習內容看似很接近,甚至在第四篇 actor-critic 中也有提及,可能會困惑 這個第七篇有什麼特別的,我也困惑了一會兒,然後我發現是自己的學習不夠仔細:

第二篇和第四篇的 價值網路 學習方法並不同。雖然都用到了 以TD target 為代表的TD 演算法。但是兩者的學習函式並不相同!

  1. Sarsa演算法 學習動作價值函式 \(Q_\pi(s,a)\)

  2. Actor-Critic 中的價值網路j就是用 Sarsa 訓練的

  3. 而第二篇 DQN 中的 TD 學習 是訓練最優動作價值函式:
    $Q ^*( s , a ) $

    而這種方法在下一篇中很快會提及,這就是 Q-learning 方法。

參考:

TD演算法總述

Sarsa演算法及其程式碼

相關文章