強化學習筆記之【SAC演算法】

这可就有点麻烦了發表於2024-10-11

強化學習筆記之【SAC演算法】


前言:

本文為強化學習筆記第四篇,第一篇講的是Q-learning和DQN,第二篇DDPG,第三篇TD3

TD3比DDPG少了一個target_actor網路,其它地方有點小改動

CSDN主頁:https://blog.csdn.net/rvdgdsva

部落格園主頁:https://www.cnblogs.com/hassle


目錄
  • 強化學習筆記之【SAC演算法】
      • 前言:
      • 一、SAC演算法
      • 二、SAC演算法Latex解釋
      • 三、SAC五大網路和模組
        • 3.1 Actor 網路
        • 3.2 Critic1 和 Critic2 網路
        • 3.3 Target Critic1 和 Target Critic2 網路
        • 3.4 軟更新模組
        • 3.5 總結

強化學習筆記之【SAC演算法】
STAND ALONE COMPLEX = S . A . C

首先,我們需要明確,Q-learning演算法發展成DQN演算法,DQN演算法發展成為DDPG演算法,而DDPG演算法發展成TD3演算法,TD3演算法發展成SAC演算法

Soft Actor-Critic (SAC) 是一種基於策略梯度的深度強化學習演算法,它具有最大化獎勵與最大化熵(探索性)的雙重目標。SAC 透過引入熵正則項,使策略在決策時具有更大的隨機性,從而提高探索能力。

一、SAC演算法

OK,先用虛擬碼讓你們感受一下SAC演算法

# 定義 SAC 超引數
alpha = 0.2               # 熵正則項係數
gamma = 0.99              # 折扣因子
tau = 0.005               # 目標網路軟更新引數
lr = 3e-4                 # 學習率

# 初始化 Actor、Critic、Target Critic 網路和最佳化器
actor = ActorNetwork()                      # 策略網路 π(s)
critic1 = CriticNetwork()                   # 第一個 Q 網路 Q1(s, a)
critic2 = CriticNetwork()                   # 第二個 Q 網路 Q2(s, a)
target_critic1 = CriticNetwork()            # 目標 Q 網路 1
target_critic2 = CriticNetwork()            # 目標 Q 網路 2

# 將目標 Q 網路的引數設定為與 Critic 網路相同
target_critic1.load_state_dict(critic1.state_dict())
target_critic2.load_state_dict(critic2.state_dict())

# 初始化最佳化器
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=lr)
critic1_optimizer = torch.optim.Adam(critic1.parameters(), lr=lr)
critic2_optimizer = torch.optim.Adam(critic2.parameters(), lr=lr)

# 經驗回放池(Replay Buffer)
replay_buffer = ReplayBuffer()

# SAC 訓練迴圈
for each iteration:
    # Step 1: 從 Replay Buffer 中取樣一個批次 (state, action, reward, next_state)
    batch = replay_buffer.sample()
    state, action, reward, next_state, done = batch

    # Step 2: 計算目標 Q 值 (y)
    with torch.no_grad():
        # 從 Actor 網路中獲取 next_state 的下一個動作
        next_action, next_log_prob = actor.sample(next_state)
        
        # 目標 Q 值的計算:使用目標 Q 網路的最小值 + 熵項
        target_q1_value = target_critic1(next_state, next_action)
        target_q2_value = target_critic2(next_state, next_action)
        min_target_q_value = torch.min(target_q1_value, target_q2_value)

        # 目標 Q 值 y = r + γ * (最小目標 Q 值 - α * next_log_prob)
        target_q_value = reward + gamma * (1 - done) * (min_target_q_value - alpha * next_log_prob)

    # Step 3: 更新 Critic 網路
    # Critic 1 損失
    current_q1_value = critic1(state, action)
    critic1_loss = F.mse_loss(current_q1_value, target_q_value)

    # Critic 2 損失
    current_q2_value = critic2(state, action)
    critic2_loss = F.mse_loss(current_q2_value, target_q_value)

    # 反向傳播並更新 Critic 網路引數
    critic1_optimizer.zero_grad()
    critic1_loss.backward()
    critic1_optimizer.step()

    critic2_optimizer.zero_grad()
    critic2_loss.backward()
    critic2_optimizer.step()

    # Step 4: 更新 Actor 網路
    # 透過 Actor 網路生成新的動作及其 log 機率
    new_action, log_prob = actor.sample(state)

    # 計算 Actor 的目標損失:L = α * log_prob - Q1(s, π(s))
    q1_value = critic1(state, new_action)
    actor_loss = (alpha * log_prob - q1_value).mean()

    # 反向傳播並更新 Actor 網路引數
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

    # Step 5: 軟更新目標 Q 網路引數
    with torch.no_grad():
        for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

        for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

二、SAC演算法Latex解釋

1、初始化 Actor、Critic1、Critic2、TargetCritic1 、TargetCritic2 網路
2、Buffer中取樣 (state, action, reward, next_state)

3、Actor 輸入 next_state 對應輸出 next_action 和 next_log_prob
4、Actor 輸入 state 對應輸出 new_action 和 log_prob
5、Critic1 和 Critic2 分別輸入next_state 和 next_action 取其中較小輸出經熵正則計算得 target_q_value

6、使用 MSE_loss(Critic1(state, action), target_q_value) 更新 Critic1
7、使用 MSE_loss(Critic2(state, action), target_q_value) 更新 Critic2
8、使用 (alpha * log_prob - critic1(state, new_action)).mean() 更新 Actor

強化學習筆記之【SAC演算法】

三、SAC五大網路和模組

SAC 演算法 中,Actor、Critic1、Critic2、Target Critic1 和 Target Critic2 網路是核心模組,它們分別用於輸出動作、評估狀態-動作對的價值,並透過目標網路進行穩定的更新。

3.1 Actor 網路

Actor 網路用於在給定狀態下輸出一個高斯分佈的均值和標準差(即策略)。它是透過神經網路近似的隨機策略。用於選擇動作。

import torch
import torch.nn as nn

class ActorNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mean_layer = nn.Linear(256, action_dim)  # 輸出動作的均值
        self.log_std_layer = nn.Linear(256, action_dim)  # 輸出動作的log標準差

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        mean = self.mean_layer(x)  # 輸出動作均值
        log_std = self.log_std_layer(x)  # 輸出 log 標準差
        log_std = torch.clamp(log_std, min=-20, max=2)  # 限制標準差範圍
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = torch.exp(log_std)  # 將 log 標準差轉為標準差
        normal = torch.distributions.Normal(mean, std)
        action = normal.rsample()  # 透過重引數化技巧進行取樣
        log_prob = normal.log_prob(action).sum(-1)  # 計算 log 機率
        return action, log_prob


3.2 Critic1 和 Critic2 網路

Critic 網路用於計算狀態-動作對的 Q 值,SAC 使用兩個 Critic 網路(Critic1 和 Critic2)來緩解 Q 值的過估計問題。

class CriticNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(CriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.q_value_layer = nn.Linear(256, 1)  # 輸出 Q 值

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)  # 將 state 和 action 作為輸入
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        q_value = self.q_value_layer(x)  # 輸出 Q 值
        return q_value


3.3 Target Critic1 和 Target Critic2 網路

Target Critic 網路的結構與 Critic 網路相同,用於穩定 Q 值更新。它們透過軟更新(即在每次訓練後慢慢接近 Critic 網路的引數)來保持訓練的穩定性。

class TargetCriticNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(TargetCriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.q_value_layer = nn.Linear(256, 1)  # 輸出 Q 值

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)  # 將 state 和 action 作為輸入
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        q_value = self.q_value_layer(x)  # 輸出 Q 值
        return q_value

3.4 軟更新模組

在 SAC 中,目標網路會透過軟更新逐漸逼近 Critic 網路的引數。每次更新後,目標網路引數會按照 ττ 的比例向 Critic 網路的引數靠攏。

def soft_update(critic, target_critic, tau=0.005):
    for param, target_param in zip(critic.parameters(), target_critic.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

3.5 總結

  1. 初始化網路和引數:
    • Actor 網路:用於選擇動作。
    • Critic 1 和 Critic 2 網路:用於估計 Q 值。
    • Target Critic 1 和 Target Critic 2:與 Critic 網路架構相同,用於生成更穩定的目標 Q 值。
  2. 目標 Q 值計算:
    • 使用目標網路計算下一狀態下的 Q 值。
    • 取兩個 Q 網路輸出的最小值,防止 Q 值的過估計。
    • 引入熵正則項,計算公式:$$y=r+\gamma\cdot\min(Q_1,Q_2)-\alpha\cdot\log\pi(a|s)$$
  3. 更新 Critic 網路:
    • 最小化目標 Q 值與當前 Q 值的均方誤差 (MSE)。
  4. 更新 Actor 網路:
    • 最大化目標損失:$$L=\alpha\cdot\log\pi(a|s)-Q_1(s,\pi(s))$$,即在保證探索的情況下選擇高價值動作。
  5. 軟更新目標網路:
    • 軟更新目標 Q 網路引數,使得目標網路引數緩慢向當前網路靠近,避免振盪。

相關文章