強化學習筆記之【SAC演算法】
前言:
本文為強化學習筆記第四篇,第一篇講的是Q-learning和DQN,第二篇DDPG,第三篇TD3
TD3比DDPG少了一個target_actor網路,其它地方有點小改動
CSDN主頁:https://blog.csdn.net/rvdgdsva
部落格園主頁:https://www.cnblogs.com/hassle
- 強化學習筆記之【SAC演算法】
- 前言:
- 一、SAC演算法
- 二、SAC演算法Latex解釋
- 三、SAC五大網路和模組
- 3.1 Actor 網路
- 3.2 Critic1 和 Critic2 網路
- 3.3 Target Critic1 和 Target Critic2 網路
- 3.4 軟更新模組
- 3.5 總結
首先,我們需要明確,Q-learning演算法發展成DQN演算法,DQN演算法發展成為DDPG演算法,而DDPG演算法發展成TD3演算法,TD3演算法發展成SAC演算法
Soft Actor-Critic (SAC) 是一種基於策略梯度的深度強化學習演算法,它具有最大化獎勵與最大化熵(探索性)的雙重目標。SAC 透過引入熵正則項,使策略在決策時具有更大的隨機性,從而提高探索能力。
一、SAC演算法
OK,先用虛擬碼讓你們感受一下SAC演算法
# 定義 SAC 超引數
alpha = 0.2 # 熵正則項係數
gamma = 0.99 # 折扣因子
tau = 0.005 # 目標網路軟更新引數
lr = 3e-4 # 學習率
# 初始化 Actor、Critic、Target Critic 網路和最佳化器
actor = ActorNetwork() # 策略網路 π(s)
critic1 = CriticNetwork() # 第一個 Q 網路 Q1(s, a)
critic2 = CriticNetwork() # 第二個 Q 網路 Q2(s, a)
target_critic1 = CriticNetwork() # 目標 Q 網路 1
target_critic2 = CriticNetwork() # 目標 Q 網路 2
# 將目標 Q 網路的引數設定為與 Critic 網路相同
target_critic1.load_state_dict(critic1.state_dict())
target_critic2.load_state_dict(critic2.state_dict())
# 初始化最佳化器
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=lr)
critic1_optimizer = torch.optim.Adam(critic1.parameters(), lr=lr)
critic2_optimizer = torch.optim.Adam(critic2.parameters(), lr=lr)
# 經驗回放池(Replay Buffer)
replay_buffer = ReplayBuffer()
# SAC 訓練迴圈
for each iteration:
# Step 1: 從 Replay Buffer 中取樣一個批次 (state, action, reward, next_state)
batch = replay_buffer.sample()
state, action, reward, next_state, done = batch
# Step 2: 計算目標 Q 值 (y)
with torch.no_grad():
# 從 Actor 網路中獲取 next_state 的下一個動作
next_action, next_log_prob = actor.sample(next_state)
# 目標 Q 值的計算:使用目標 Q 網路的最小值 + 熵項
target_q1_value = target_critic1(next_state, next_action)
target_q2_value = target_critic2(next_state, next_action)
min_target_q_value = torch.min(target_q1_value, target_q2_value)
# 目標 Q 值 y = r + γ * (最小目標 Q 值 - α * next_log_prob)
target_q_value = reward + gamma * (1 - done) * (min_target_q_value - alpha * next_log_prob)
# Step 3: 更新 Critic 網路
# Critic 1 損失
current_q1_value = critic1(state, action)
critic1_loss = F.mse_loss(current_q1_value, target_q_value)
# Critic 2 損失
current_q2_value = critic2(state, action)
critic2_loss = F.mse_loss(current_q2_value, target_q_value)
# 反向傳播並更新 Critic 網路引數
critic1_optimizer.zero_grad()
critic1_loss.backward()
critic1_optimizer.step()
critic2_optimizer.zero_grad()
critic2_loss.backward()
critic2_optimizer.step()
# Step 4: 更新 Actor 網路
# 透過 Actor 網路生成新的動作及其 log 機率
new_action, log_prob = actor.sample(state)
# 計算 Actor 的目標損失:L = α * log_prob - Q1(s, π(s))
q1_value = critic1(state, new_action)
actor_loss = (alpha * log_prob - q1_value).mean()
# 反向傳播並更新 Actor 網路引數
actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()
# Step 5: 軟更新目標 Q 網路引數
with torch.no_grad():
for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
二、SAC演算法Latex解釋
1、初始化 Actor、Critic1、Critic2、TargetCritic1 、TargetCritic2 網路
2、Buffer中取樣 (state, action, reward, next_state)
3、Actor 輸入 next_state 對應輸出 next_action 和 next_log_prob
4、Actor 輸入 state 對應輸出 new_action 和 log_prob
5、Critic1 和 Critic2 分別輸入next_state 和 next_action 取其中較小輸出經熵正則計算得 target_q_value
6、使用 MSE_loss(Critic1(state, action), target_q_value) 更新 Critic1
7、使用 MSE_loss(Critic2(state, action), target_q_value) 更新 Critic2
8、使用 (alpha * log_prob - critic1(state, new_action)).mean() 更新 Actor
三、SAC五大網路和模組
在 SAC 演算法 中,Actor、Critic1、Critic2、Target Critic1 和 Target Critic2 網路是核心模組,它們分別用於輸出動作、評估狀態-動作對的價值,並透過目標網路進行穩定的更新。
3.1 Actor 網路
Actor 網路用於在給定狀態下輸出一個高斯分佈的均值和標準差(即策略)。它是透過神經網路近似的隨機策略。用於選擇動作。
import torch
import torch.nn as nn
class ActorNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(ActorNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.mean_layer = nn.Linear(256, action_dim) # 輸出動作的均值
self.log_std_layer = nn.Linear(256, action_dim) # 輸出動作的log標準差
def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
mean = self.mean_layer(x) # 輸出動作均值
log_std = self.log_std_layer(x) # 輸出 log 標準差
log_std = torch.clamp(log_std, min=-20, max=2) # 限制標準差範圍
return mean, log_std
def sample(self, state):
mean, log_std = self.forward(state)
std = torch.exp(log_std) # 將 log 標準差轉為標準差
normal = torch.distributions.Normal(mean, std)
action = normal.rsample() # 透過重引數化技巧進行取樣
log_prob = normal.log_prob(action).sum(-1) # 計算 log 機率
return action, log_prob
3.2 Critic1 和 Critic2 網路
Critic 網路用於計算狀態-動作對的 Q 值,SAC 使用兩個 Critic 網路(Critic1 和 Critic2)來緩解 Q 值的過估計問題。
class CriticNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(CriticNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim + action_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.q_value_layer = nn.Linear(256, 1) # 輸出 Q 值
def forward(self, state, action):
x = torch.cat([state, action], dim=-1) # 將 state 和 action 作為輸入
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
q_value = self.q_value_layer(x) # 輸出 Q 值
return q_value
3.3 Target Critic1 和 Target Critic2 網路
Target Critic 網路的結構與 Critic 網路相同,用於穩定 Q 值更新。它們透過軟更新(即在每次訓練後慢慢接近 Critic 網路的引數)來保持訓練的穩定性。
class TargetCriticNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(TargetCriticNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim + action_dim, 256)
self.fc2 = nn.Linear(256, 256)
self.q_value_layer = nn.Linear(256, 1) # 輸出 Q 值
def forward(self, state, action):
x = torch.cat([state, action], dim=-1) # 將 state 和 action 作為輸入
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
q_value = self.q_value_layer(x) # 輸出 Q 值
return q_value
3.4 軟更新模組
在 SAC 中,目標網路會透過軟更新逐漸逼近 Critic 網路的引數。每次更新後,目標網路引數會按照 ττ 的比例向 Critic 網路的引數靠攏。
def soft_update(critic, target_critic, tau=0.005):
for param, target_param in zip(critic.parameters(), target_critic.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
3.5 總結
- 初始化網路和引數:
- Actor 網路:用於選擇動作。
- Critic 1 和 Critic 2 網路:用於估計 Q 值。
- Target Critic 1 和 Target Critic 2:與 Critic 網路架構相同,用於生成更穩定的目標 Q 值。
- 目標 Q 值計算:
- 使用目標網路計算下一狀態下的 Q 值。
- 取兩個 Q 網路輸出的最小值,防止 Q 值的過估計。
- 引入熵正則項,計算公式:$$y=r+\gamma\cdot\min(Q_1,Q_2)-\alpha\cdot\log\pi(a|s)$$
- 更新 Critic 網路:
- 最小化目標 Q 值與當前 Q 值的均方誤差 (MSE)。
- 更新 Actor 網路:
- 最大化目標損失:$$L=\alpha\cdot\log\pi(a|s)-Q_1(s,\pi(s))$$,即在保證探索的情況下選擇高價值動作。
- 軟更新目標網路:
- 軟更新目標 Q 網路引數,使得目標網路引數緩慢向當前網路靠近,避免振盪。