Sarsa演算法 是 TD演算法的一種,之前沒有嚴謹推導過 TD 演算法,這一篇就來從數學的角度推導一下 Sarsa 演算法。注意,這部分屬於 TD演算法的延申。
7. Sarsa演算法
7.1 推導 TD target
推導:Derive。
這一部分就是Sarsa 最重要的核心。
折扣回報:$U_t=R_t+\gamma R_{t+1}+\gamma^2 R_{t+2}+\gamma^3 R_{t+3}+\cdots \ \quad={R_t} + \gamma \cdot U_{t+1} $
即 將\(R_{t+1}\)之後 都提出一個 \(\gamma\) 項,後面括號中的式子意義正為 \(U_{t+1}\)
通常認為獎勵 \(R_t\)依賴於 t 時刻的狀態 \(S_t\) 與 動作 \(A_t\) 以及 t+1 時刻的狀態 \(S_{t+1}\)。
當時對於為什麼依賴於 \(S_{t+1}\) 有疑問,我回去翻看了 學習筆記1:https://www.cnblogs.com/Roboduster/p/16442003.html ,發現並強調了以下這一點:
“值得注意的是,這個 r1 是什麼時候給的?是在狀態 state s2 的時候給的。”
狀態價值函式 \(Q_\pi({s_t},{a_t}) = \mathbb{E}[U_t|{s_t},{a_t}]\) 是回報 \(U_t\) 的期望;
- 用折扣回報的變換式,把\(U_t\)替換掉:\(Q_\pi({s_t},{a_t}) = \mathbb{E}[{R_t} + \gamma \cdot U_{t+1} |{s_t}{a_t}]\)
- 有兩項期望,分解開:\(= \mathbb{E}[{R_t} |{s_t},{a_t}] + \gamma \cdot\mathbb{E}[ U_{t+1} |{s_t},{a_t}]\)
下面研究上式的第二項:\(\mathbb{E}[ U_{t+1} |{s_t},{a_t}]\)
其等於 \(\mathbb{E}[ Q_\pi({s_{t+1}},{a_{t+1}}) |{s_t},{a_t}]\)
Q 是 U 的期望:所以 \(E(E[])=E()\),期望的期望還是原來的期望;這裡是逆用這個性質。這麼做是為了讓等式兩邊都有 \(Q_\pi\) 函式,如下:
於是便得到: \(Q_\pi({s_t},{a_t}) =\mathbb{E}[{R_t} |{s_t},{a_t}] + \gamma\cdot\mathbb{E}[ Q_\pi({s_{t+1}},{a_{t+1}}) {s_t},{a_t}] \\ Q_\pi({s_t},{a_t})=\mathbb{E}[{R_t} + \gamma \cdot Q_\pi({S_{t+1}},{A_{t+1}})]\)
右側有一個期望,但直接求期望很困難,所以通常是對期望求蒙特卡洛近似。
- \(R_t\) 近似為觀測到獎勵\(r_t\)
- \(Q_\pi({S_{t+1}},{A_{t+1}})\)用觀測到的 \(Q_\pi({s_{t+1}},{a_{t+1}})\) 來近似
- 得到蒙特卡洛近似值\(\approx {r_t} + \gamma \cdot Q_\pi({s_{t+1}},{a_{t+1}})\)
- 將這個值表示為 TD target \(y_t\)
TD learning 目標:讓 $Q_\pi({s_t},{a_t}) $ 來接近部分真實的獎勵 \(y_t\)。
\(Q_\pi\) 完全是估計,而 \(y_t\) 包含了一部分真實獎勵,所以 \(y_t\) 更可靠。
7.2 Sarsa演算法過程
這是一種TD 演算法。
a. 表格形式
如果我們想要學習動作價值 $Q_\pi({s_t},{a_t}) $,假設狀態和動作都是有限的,可以畫一個表來表示:
- 表每個元素代表一個動作價值;
- 用 Sarsa 演算法更新表格,每次更新一個元素;
-
在表格形式中,每次觀測到一個四元組\(({s_t},{a_t},{r_t},{s_{t+1}})\),稱為一個 transition
-
根據策略函式 \(\pi\) 隨機取樣計算下一個動作,記作\({a_{t+1}}\sim\pi(\cdot|{s_{t+1}})\);
-
計算TD target: \(y_t = {r_t} + \gamma \cdot Q_\pi({s_{t+1}},{a_{t+1}})\),
前一部分是觀測到的獎勵,後面一部分是對未來動作的打分,\(Q_\pi({s_{t+1}},{a_{t+1}})\) 可以通過查表得知。
表最開始是通過一定方式初始化的(比如隨機),然後通過不斷計算來更新表格。
通過查表,還知道\(Q_\pi({s_{t}},{a_{t}})\)的值,可以計算:
-
TD error:\(\delta_t = Q_\pi({s_{t}},{a_{t}}) -y_t\);
-
最後用 \(\delta_t\) 來更新:\(Q_\pi({s_{t}},{a_{t}}) \leftarrow Q_\pi({s_{t}},{a_{t}}) - \alpha \cdot \delta_t\),並寫入表格相應的位置
$\alpha $是學習率。通過TD error 更新,可以讓 Q 更好的接近 \(y_t\)。
每一步中,Sarsa 演算法用 \((s_t,a_T,r_t,s_{t+1},a_{t+1})\) 來更新 \(Q_\pi\),sarsa,這就是演算法名字的由來。
b. 神經網路形式
值得留意的是表格形式的假設:假設狀態和動作都是有限的,而當狀態和動作很多,表格就會很大,很難學習。
-
用神經網路-價值網路 \(q({s},{a};w)\) 來近似\(Q_\pi({s},{a})\),Sarsa演算法可以訓練這個價值網路。
- actor-critic 那篇用過 Sarsa 演算法,想不起來往下看:
- q 和 Q 都與 策略函式 \(\pi\) 有關。
- 網路引數 \(\omega\) 初始時隨機初始化,後續不斷更新。
輸入狀態是 s ,輸出就是所有動作的價值
- actor-critic 方法中,q 作為 critic 用來評估 actor;用 sarsa 這一 TD 學習演算法更新的價值網路。
- TD target: \(y_t = {r_t} + \gamma \cdot q({s_{t+1}},{a_{t+1}};w)\)
- TD error:\(\delta_t = q({s_{t}},{a_{t}};w) - y_t\)
- Loss: \(\delta_t ^2/2\),我們的目的是通過更新網路引數 w 來降低 Loss;
- 梯度:\(\frac{\partial\delta_t ^2/2}{\partial w} = \delta_t \cdot \frac{\partial q({s_{t}},{a_{t}};w)}{\partial w}\)
- 梯度下降更新 w:$$w \leftarrow w - \alpha \cdot \delta_t \cdot \frac{\partial q({s_{t}},{a_{t}};w)}{\partial w}$$
7.3 一些解惑 / 有什麼不同
這一篇跟第二篇價值學習內容看似很接近,甚至在第四篇 actor-critic 中也有提及,可能會困惑 這個第七篇有什麼特別的,我也困惑了一會兒,然後我發現是自己的學習不夠仔細:
第二篇和第四篇的 價值網路 學習方法並不同。雖然都用到了 以TD target 為代表的TD 演算法。但是兩者的學習函式並不相同!
Sarsa演算法 學習動作價值函式 \(Q_\pi(s,a)\)
Actor-Critic 中的價值網路j就是用 Sarsa 訓練的
而第二篇 DQN 中的 TD 學習 是訓練最優動作價值函式:
$Q ^*( s , a ) $而這種方法在下一篇中很快會提及,這就是 Q-learning 方法。
參考: