一文教你在MindSpore中實現A2C演算法訓練

华为云开发者联盟發表於2024-06-07

本文分享自華為雲社群《MindSpore A2C 強化學習》,作者:irrational。

Advantage Actor-Critic (A2C)演算法是一個強化學習演算法,它結合了策略梯度(Actor)和價值函式(Critic)的方法。A2C演算法在許多強化學習任務中表現優越,因為它能夠利用價值函式來減少策略梯度的方差,同時直接最佳化策略。

A2C演算法的核心思想

  • Actor:根據當前策略選擇動作。
  • Critic:評估一個狀態-動作對的值(通常是使用狀態值函式或動作值函式)。
  • 優勢函式(Advantage Function):用來衡量某個動作相對於平均水平的好壞,通常定義為A(s,a)=Q(s,a)−V(s)。

A2C演算法的虛擬碼

以下是A2C演算法的虛擬碼:

Initialize policy network (actor) π with parameters θ
Initialize value network (critic) V with parameters w
Initialize learning rates α_θ for policy network and α_w for value network

for each episode do
    Initialize state s
    while state s is not terminal do
        # Actor: select action a according to the current policy π(a|s; θ)
        a = select_action(s, θ)
        
        # Execute action a in the environment, observe reward r and next state s'
        r, s' = environment.step(a)
        
        # Critic: compute the value of the current state V(s; w)
        V_s = V(s, w)
        
        # Critic: compute the value of the next state V(s'; w)
        V_s_prime = V(s', w)
        
        # Compute the TD error (δ)
        δ = r + γ * V_s_prime - V_s
        
        # Critic: update the value network parameters w
        w = w + α_w * δ * ∇_w V(s; w)
        
        # Compute the advantage function A(s, a)
        A = δ
        
        # Actor: update the policy network parameters θ
        θ = θ + α_θ * A * ∇_θ log π(a|s; θ)
        
        # Move to the next state
        s = s'
    end while
end for

解釋

  1. 初始化:初始化策略網路(Actor)和價值網路(Critic)的引數,以及它們的學習率。
  2. 迴圈每個Episode:在每個Episode開始時,初始化狀態。
  3. 選擇動作:根據當前策略從Actor中選擇動作。
  4. 執行動作:在環境中執行動作,並觀察獎勵和下一個狀態。
  5. 計算狀態值:用Critic評估當前狀態和下一個狀態的值。
  6. 計算TD誤差:計算時序差分誤差(Temporal Difference Error),它是當前獎勵加上下一個狀態的折扣值與當前狀態值的差。
  7. 更新Critic:根據TD誤差更新價值網路的引數。
  8. 計算優勢函式:使用TD誤差計算優勢函式。
  9. 更新Actor:根據優勢函式更新策略網路的引數。
  10. 更新狀態:移動到下一個狀態,重複上述步驟,直到Episode結束。

這個虛擬碼展示了A2C演算法的核心步驟,實際實現中可能會有更多細節,如使用折扣因子γ、多個並行環境等。

程式碼如下:

import argparse

from mindspore import context
from mindspore import dtype as mstype
from mindspore.communication import init

from mindspore_rl.algorithm.a2c import config
from mindspore_rl.algorithm.a2c.a2c_session import A2CSession
from mindspore_rl.algorithm.a2c.a2c_trainer import A2CTrainer

parser = argparse.ArgumentParser(description="MindSpore Reinforcement A2C")
parser.add_argument("--episode", type=int, default=10000, help="total episode numbers.")
parser.add_argument(
    "--device_target",
    type=str,
    default="CPU",
    choices=["CPU", "GPU", "Ascend", "Auto"],
    help="Choose a devioptions.device_targece to run the ac example(Default: Auto).",
)
parser.add_argument(
    "--precision_mode",
    type=str,
    default="fp32",
    choices=["fp32", "fp16"],
    help="Precision mode",
)
parser.add_argument(
    "--env_yaml",
    type=str,
    default="../env_yaml/CartPole-v0.yaml",
    help="Choose an environment yaml to update the a2c example(Default: CartPole-v0.yaml).",
)
parser.add_argument(
    "--algo_yaml",
    type=str,
    default=None,
    help="Choose an algo yaml to update the a2c example(Default: None).",
)
parser.add_argument(
    "--enable_distribute",
    type=bool,
    default=False,
    help="Train in distribute mode (Default: False).",
)
parser.add_argument(
    "--worker_num",
    type=int,
    default=2,
    help="Worker num (Default: 2).",
)
options, _ = parser.parse_known_args()

首先初始化引數,然後我這裡用cpu執行:options.device_targe = “CPU”

episode=options.episode
"""Train a2c"""
if options.device_target != "Auto":
    context.set_context(device_target=options.device_target)
if context.get_context("device_target") in ["CPU", "GPU"]:
    context.set_context(enable_graph_kernel=True)
context.set_context(mode=context.GRAPH_MODE)
compute_type = (
    mstype.float32 if options.precision_mode == "fp32" else mstype.float16
)
config.algorithm_config["policy_and_network"]["params"][
    "compute_type"
] = compute_type
if compute_type == mstype.float16 and options.device_target != "Ascend":
    raise ValueError("Fp16 mode is supported by Ascend backend.")
is_distribte = options.enable_distribute
if is_distribte:
    init()
    context.set_context(enable_graph_kernel=False)
    config.deploy_config["worker_num"] = options.worker_num
a2c_session = A2CSession(options.env_yaml, options.algo_yaml, is_distribte)

設定上下文管理器

import sys
import time
from io import StringIO

class RealTimeCaptureAndDisplayOutput(object):
    def __init__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        self.captured_output = StringIO()

    def write(self, text):
        self._original_stdout.write(text)  # 實時列印
        self.captured_output.write(text)   # 儲存到緩衝區

    def flush(self):
        self._original_stdout.flush()
        self.captured_output.flush()

    def __enter__(self):
        sys.stdout = self
        sys.stderr = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr
episode=10
# dqn_session.run(class_type=DQNTrainer, episode=episode)
with RealTimeCaptureAndDisplayOutput() as captured_new:
    a2c_session.run(class_type=A2CTrainer, episode=episode)
import re
import matplotlib.pyplot as plt

# 原始輸出
raw_output = captured_new.captured_output.getvalue()

# 使用正規表示式從輸出中提取loss和rewards
loss_pattern = r"loss=(\d+\.\d+)"
reward_pattern = r"running_reward=(\d+\.\d+)"
loss_values = [float(match.group(1)) for match in re.finditer(loss_pattern, raw_output)]
reward_values = [float(match.group(1)) for match in re.finditer(reward_pattern, raw_output)]

# 繪製loss曲線
plt.plot(loss_values, label='Loss')
plt.xlabel('Episode')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.show()

# 繪製reward曲線
plt.plot(reward_values, label='Rewards')
plt.xlabel('Episode')
plt.ylabel('Rewards')
plt.title('Rewards Curve')
plt.legend()
plt.show()

展示結果:
image.png

image.png

下面我將詳細解釋你提供的 MindSpore A2C 演算法訓練配置引數的含義:

Actor 配置

'actor': {
  'number': 1,
  'type': mindspore_rl.algorithm.a2c.a2c.A2CActor,
  'params': {
    'collect_environment': PyFuncWrapper<
       (_envs): GymEnvironment<>
     >,
   'eval_environment': PyFuncWrapper<
     (_envs): GymEnvironment<>
     >,
   'replay_buffer': None,
   'a2c_net': ActorCriticNet<
     (common): Dense<input_channels=4, output_channels=128, has_bias=True>
     (actor): Dense<input_channels=128, output_channels=2, has_bias=True>
     (critic): Dense<input_channels=128, output_channels=1, has_bias=True>
     (relu): LeakyReLU<>
     >},
  'policies': [],
  'networks': ['a2c_net']
}
  • number: Actor 的例項數量,這裡設定為1,表示使用一個 Actor 例項。
  • type: Actor 的型別,這裡使用 mindspore_rl.algorithm.a2c.a2c.A2CActor
  • params: Actor 的引數配置。
    • collect_environmenteval_environment: 使用 PyFuncWrapper 包裝的 GymEnvironment,用於資料收集和評估環境。
    • replay_buffer: 設定為 None,表示不使用經驗回放緩衝區。
    • a2c_net: Actor-Critic 網路,包含一個公共層、一個 Actor 層和一個 Critic 層,以及一個 Leaky ReLU 啟用函式。
  • policiesnetworks: Actor 關聯的策略和網路,這裡主要是 a2c_net

Learner 配置

'learner': {
  'number': 1,
  'type': mindspore_rl.algorithm.a2c.a2c.A2CLearner,
  'params': {
    'gamma': 0.99,
    'state_space_dim': 4,
    'action_space_dim': 2,
    'a2c_net': ActorCriticNet<
      (common): Dense<input_channels=4, output_channels=128, has_bias=True>
      (actor): Dense<input_channels=128, output_channels=2, has_bias=True>
      (critic): Dense<input_channels=128, output_channels=1, has_bias=True>
      (relu): LeakyReLU<>
    >,
    'a2c_net_train': TrainOneStepCell<
      (network): Loss<
        (a2c_net): ActorCriticNet<
          (common): Dense<input_channels=4, output_channels=128, has_bias=True>
          (actor): Dense<input_channels=128, output_channels=2, has_bias=True>
          (critic): Dense<input_channels=128, output_channels=1, has_bias=True>
          (relu): LeakyReLU<>
        >
        (smoothl1_loss): SmoothL1Loss<>
      >
      (optimizer): Adam<>
      (grad_reducer): Identity<>
    >
  },
  'networks': ['a2c_net_train', 'a2c_net']
}
  • number: Learner 的例項數量,這裡設定為1,表示使用一個 Learner 例項。
  • type: Learner 的型別,這裡使用 mindspore_rl.algorithm.a2c.a2c.A2CLearner
  • params: Learner 的引數配置。
    • gamma: 折扣因子,用於未來獎勵的折扣計算。
    • state_space_dim: 狀態空間的維度,這裡為4。
    • action_space_dim: 動作空間的維度,這裡為2。
    • a2c_net: Actor-Critic 網路定義,與 Actor 中相同。
    • a2c_net_train: 用於訓練的網路,包含損失函式(SmoothL1Loss)、最佳化器(Adam)和梯度縮減器(Identity)。
  • networks: Learner 關聯的網路,包括 a2c_net_traina2c_net

Policy and Network 配置

'policy_and_network': {
  'type': mindspore_rl.algorithm.a2c.a2c.A2CPolicyAndNetwork,
  'params': {
    'lr': 0.01,
    'state_space_dim': 4,
    'action_space_dim': 2,
    'hidden_size': 128,
    'gamma': 0.99,
    'compute_type': mindspore.float32,
    'environment_config': {
      'id': 'CartPole-v0',
      'entry_point': 'gym.envs.classic_control:CartPoleEnv',
      'reward_threshold': 195.0,
      'nondeterministic': False,
      'max_episode_steps': 200,
      '_kwargs': {},
      '_env_name': 'CartPole'
    }
  }
}
  • type: 策略和網路的型別,這裡使用 mindspore_rl.algorithm.a2c.a2c.A2CPolicyAndNetwork
  • params: 策略和網路的引數配置。
    • lr: 學習率,這裡為0.01。
    • state_space_dimaction_space_dim: 狀態和動作空間的維度。
    • hidden_size: 隱藏層的大小,這裡為128。
    • gamma: 折扣因子。
    • compute_type: 計算型別,這裡為 mindspore.float32
    • environment_config: 環境配置,包括環境 ID、入口、獎勵閾值、最大步數等。

Collect Environment 配置

'collect_environment': {
  'number': 1,
  'type': mindspore_rl.environment.gym_environment.GymEnvironment,
  'wrappers': [mindspore_rl.environment.pyfunc_wrapper.PyFuncWrapper],
  'params': {
    'GymEnvironment': {
      'name': 'CartPole-v0',
      'seed': 42
    },
    'name': 'CartPole-v0'
  }
}
  • number: 環境例項數量,這裡為1。
  • type: 環境的型別,這裡使用 mindspore_rl.environment.gym_environment.GymEnvironment
  • wrappers: 環境使用的包裝器,這裡是 PyFuncWrapper
  • params: 環境的引數配置,包括環境名稱 CartPole-v0 和隨機種子 42

Eval Environment 配置

'eval_environment': {
  'number': 1,
  'type': mindspore_rl.environment.gym_environment.GymEnvironment,
  'wrappers': [mindspore_rl.environment.pyfunc_wrapper.PyFuncWrapper],
  'params': {
    'GymEnvironment': {
      'name': 'CartPole-v0',
      'seed': 42
    },
    'name': 'CartPole-v0'
  }
}
  • 配置與 collect_environment 類似,用於評估模型效能。

總結一下,這些配置定義了 Actor-Critic 演算法在 MindSpore 框架中的具體實現,包括 Actor 和 Learner 的設定、策略和網路的引數,以及訓練和評估環境的配置。這個還是比較基礎的。

點選關注,第一時間瞭解華為雲新鮮技術~

相關文章