一個連續動作空間的SAC的例子

南乡水發表於2024-10-10
"""My SAC continuous demo"""

import argparse
import copy
import gym
import numpy as np
import torch
import torch.nn.functional as F

from torch import nn
from torch.distributions import Normal


def parse_args() -> argparse.Namespace:
    """Parse arguments."""
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument(
        "--log_path", type=str, help="Model path", default="./training_log/"
    )
    parser.add_argument(
        "--max_buffer_size", type=int, help="Max buffer size", default=100000
    )
    parser.add_argument(
        "--min_buffer_size", type=int, help="Min buffer size", default=50000
    )
    parser.add_argument("--hidden_width", type=int, help="Hidden width", default=256)
    parser.add_argument(
        "--gamma",
        type=float,
        help="gamma",
        default=0.99,
    )
    parser.add_argument("--tau", type=float, help="tau", default=0.005)
    parser.add_argument(
        "--learning_rate", type=float, help="Learning rate", default=1e-3
    )
    parser.add_argument(
        "--max_train_steps", type=int, help="Max training steps", default=100000
    )
    parser.add_argument("--batch_size", type=int, help="Batch size", default=256)
    parser.add_argument(
        "--evaluate_freqency", type=int, help="Evaluate freqency", default=10000
    )
    return parser.parse_args()


class ReplayBuffer:
    """Replay buffer for storing transitions."""

    def __init__(self, state_dim: int, action_dim: int) -> None:
        self.max_size = int(args.max_buffer_size)
        self.count = 0
        self.size = 0
        self.state = np.zeros((self.max_size, state_dim))
        self.action = np.zeros((self.max_size, action_dim))
        self.reward = np.zeros((self.max_size, 1))
        self.next_state = np.zeros((self.max_size, state_dim))
        self.done = np.zeros((self.max_size, 1))

    def store(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        next_state: np.ndarray,
        done: np.ndarray,
    ) -> None:
        """Store a transition in the replay buffer."""
        self.state[self.count] = state
        self.action[self.count] = action
        self.reward[self.count] = reward
        self.next_state[self.count] = next_state
        self.done[self.count] = done
        self.count = (self.count + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size: int) -> tuple:
        """Sample a batch of transitions."""
        index = np.random.choice(self.size, size=batch_size)
        batch_state = torch.tensor(self.state[index], dtype=torch.float)
        batch_action = torch.tensor(self.action[index], dtype=torch.float)
        batch_reward = torch.tensor(self.reward[index], dtype=torch.float)
        batch_next_state = torch.tensor(self.next_state[index], dtype=torch.float)
        batch_done = torch.tensor(self.done[index], dtype=torch.float)
        return batch_state, batch_action, batch_reward, batch_next_state, batch_done


class Actor(nn.Module):
    """Actor network."""

    def __init__(
        self, state_dim: int, action_dim: int, hidden_width: int, max_action: float
    ) -> None:
        super().__init__()
        self.max_action = max_action
        self.in_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
        )
        self.res_layer = nn.Sequential(
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, hidden_width),
        )
        self.out_layer = nn.Sequential(
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
        )
        self.mean_layer = nn.Sequential(nn.ReLU(), nn.Linear(hidden_width, action_dim))
        self.log_std_layer = nn.Sequential(
            nn.ReLU(inplace=True), nn.Linear(hidden_width, action_dim)
        )

    def forward(self, x: torch.Tensor, deterministic: bool = False) -> tuple:
        """Forward pass."""
        x = self.in_layer(x)
        x = self.out_layer(x + self.res_layer(x))
        mean = self.mean_layer(x)
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, -20, 2)
        std = torch.exp(log_std)
        dist = Normal(mean, std)
        if deterministic:
            action = mean
        else:
            action = dist.rsample()
        log_pi = dist.log_prob(action).sum(dim=1, keepdim=True)
        log_pi -= (2 * (np.log(2) - action - F.softplus(-2 * action))).sum(
            dim=1, keepdim=True
        )
        action = self.max_action * torch.tanh(action)
        return action, log_pi


class Critic(nn.Module):
    """Critic network."""

    def __init__(self, state_dim: int, action_dim: int, hidden_width: int) -> None:
        super().__init__()
        self.in_layer1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
        )
        self.res_layer1 = nn.Sequential(
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, hidden_width),
        )
        self.out_layer1 = nn.Sequential(
            nn.ReLU(inplace=True), nn.Linear(hidden_width, 1)
        )
        self.in_layer2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
        )
        self.res_layer2 = nn.Sequential(
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, hidden_width),
        )
        self.out_layer2 = nn.Sequential(
            nn.ReLU(inplace=True), nn.Linear(hidden_width, 1)
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> tuple:
        """Forward pass."""
        state_action = torch.cat([state, action], 1)
        q1 = self.in_layer1(state_action)
        q1 = self.out_layer1(q1 + self.res_layer1(q1))
        q2 = self.in_layer2(state_action)
        q2 = self.out_layer2(q2 + self.res_layer2(q2))
        return q1, q2


class SACContinuous:
    """Soft Actor-Critic for continuous action space."""

    def __init__(self, state_dim: int, action_dim: int, max_action: float) -> None:
        self.gamma = args.gamma
        self.tau = args.tau
        self.batch_size = args.batch_size
        self.learning_rate = args.learning_rate
        self.hidden_width = args.hidden_width
        self.max_action = max_action
        self.target_entropy = -action_dim
        self.log_alpha = torch.zeros(1, requires_grad=True)
        self.alpha = self.log_alpha.exp()
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=self.learning_rate)
        self.actor = Actor(state_dim, action_dim, self.hidden_width, max_action)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=self.learning_rate
        )
        self.critic = Critic(state_dim, action_dim, self.hidden_width)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=self.learning_rate
        )

    def choose_action(
        self, state: np.ndarray, deterministic: bool = False
    ) -> np.ndarray:
        """Choose action."""
        state = torch.unsqueeze(torch.tensor(state, dtype=torch.float), 0)
        action, _ = self.actor(state, deterministic)
        return action.data.numpy().flatten()

    def learn(self, relay_buffer: ReplayBuffer) -> None:
        """Learn."""
        batch_state, batch_action, batch_reward, batch_next_state, batch_done = (
            relay_buffer.sample(self.batch_size)
        )
        with torch.no_grad():
            batch_next_action, log_pi_ = self.actor(batch_next_state)
            target_q1, target_q2 = self.critic_target(
                batch_next_state, batch_next_action
            )
            target_q = batch_reward + self.gamma * (1 - batch_done) * (
                torch.min(target_q1, target_q2) - self.alpha * log_pi_
            )
        current_q1, current_q2 = self.critic(batch_state, batch_action)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(
            current_q2, target_q
        )
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        for params in self.critic.parameters():
            params.requires_grad = False
        action, log_pi = self.actor(batch_state)
        q1, q2 = self.critic(batch_state, action)
        q = torch.min(q1, q2)
        actor_loss = (self.alpha * log_pi - q).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        for params in self.critic.parameters():
            params.requires_grad = True
        alpha_loss = -(
            self.log_alpha.exp() * (log_pi + self.target_entropy).detach()
        ).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp()
        for param, target_param in zip(
            self.critic.parameters(), self.critic_target.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )


def evaluate_policy(env, agent: SACContinuous) -> float:
    """Evaluate the policy."""
    state = env.reset()[0]
    done = False
    episode_reward = 0
    action_num = 0
    while not done:
        action = agent.choose_action(state, deterministic=True)
        next_statue, reward, done, _, _ = env.step(action)
        episode_reward += reward
        state = next_statue
        action_num += 1
        if action_num >= 1000:
            print("action_num too large.")
            break
        if episode_reward <= -1000:
            print("episode_reward too small.")
            break
    return episode_reward


def training() -> None:
    """My demo training function."""
    env_name = "Pendulum-v1"
    env = gym.make(env_name)
    env_evaluate = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    agent = SACContinuous(state_dim, action_dim, max_action)
    replay_buffer = ReplayBuffer(state_dim, action_dim)
    evaluate_num = 0
    total_steps = 0
    while total_steps < args.max_train_steps:
        state = env.reset()[0]
        episode_steps = 0
        done = False
        while not done:
            episode_steps += 1
            action = agent.choose_action(state)
            next_state, reward, done, _, _ = env.step(action)
            replay_buffer.store(state, action, reward, next_state, done)
            state = next_state
            if total_steps >= args.min_buffer_size:
                agent.learn(replay_buffer)
            if (total_steps + 1) % args.evaluate_freqency == 0:
                evaluate_num += 1
                evaluate_reward = evaluate_policy(env_evaluate, agent)
                print(
                    f"evaluate_num: {evaluate_num} \t evaluate_reward: {evaluate_reward}"
                )
            total_steps += 1
            if total_steps >= args.max_train_steps:
                break
    env.close()
    torch.save(agent.actor.state_dict(), f"{args.log_path}/trained_model.pth")


def testing() -> None:
    """My demo testing function."""
    env_name = "Pendulum-v1"
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    agent = SACContinuous(state_dim, action_dim, max_action)
    agent.actor.load_state_dict(torch.load(f"{args.log_path}/trained_model.pth"))
    state = env.reset()[0]
    total_rewards = 0
    for _ in range(1000):
        env.render()
        action = agent.choose_action(state)
        new_state, reward, _, _, _ = env.step(action)
        total_rewards += reward
        state = new_state
    env.close()
    print(f"SAC actor scores: {total_rewards}")


if __name__ == "__main__":
    args = parse_args()
    training()
    testing()

  

相關文章