"""My SAC continuous demo""" import argparse import copy import gym import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.distributions import Normal def parse_args() -> argparse.Namespace: """Parse arguments.""" parser = argparse.ArgumentParser(description="Training") parser.add_argument( "--log_path", type=str, help="Model path", default="./training_log/" ) parser.add_argument( "--max_buffer_size", type=int, help="Max buffer size", default=100000 ) parser.add_argument( "--min_buffer_size", type=int, help="Min buffer size", default=50000 ) parser.add_argument("--hidden_width", type=int, help="Hidden width", default=256) parser.add_argument( "--gamma", type=float, help="gamma", default=0.99, ) parser.add_argument("--tau", type=float, help="tau", default=0.005) parser.add_argument( "--learning_rate", type=float, help="Learning rate", default=1e-3 ) parser.add_argument( "--max_train_steps", type=int, help="Max training steps", default=100000 ) parser.add_argument("--batch_size", type=int, help="Batch size", default=256) parser.add_argument( "--evaluate_freqency", type=int, help="Evaluate freqency", default=10000 ) return parser.parse_args() class ReplayBuffer: """Replay buffer for storing transitions.""" def __init__(self, state_dim: int, action_dim: int) -> None: self.max_size = int(args.max_buffer_size) self.count = 0 self.size = 0 self.state = np.zeros((self.max_size, state_dim)) self.action = np.zeros((self.max_size, action_dim)) self.reward = np.zeros((self.max_size, 1)) self.next_state = np.zeros((self.max_size, state_dim)) self.done = np.zeros((self.max_size, 1)) def store( self, state: np.ndarray, action: np.ndarray, reward: np.ndarray, next_state: np.ndarray, done: np.ndarray, ) -> None: """Store a transition in the replay buffer.""" self.state[self.count] = state self.action[self.count] = action self.reward[self.count] = reward self.next_state[self.count] = next_state self.done[self.count] = done self.count = (self.count + 1) % self.max_size self.size = min(self.size + 1, self.max_size) def sample(self, batch_size: int) -> tuple: """Sample a batch of transitions.""" index = np.random.choice(self.size, size=batch_size) batch_state = torch.tensor(self.state[index], dtype=torch.float) batch_action = torch.tensor(self.action[index], dtype=torch.float) batch_reward = torch.tensor(self.reward[index], dtype=torch.float) batch_next_state = torch.tensor(self.next_state[index], dtype=torch.float) batch_done = torch.tensor(self.done[index], dtype=torch.float) return batch_state, batch_action, batch_reward, batch_next_state, batch_done class Actor(nn.Module): """Actor network.""" def __init__( self, state_dim: int, action_dim: int, hidden_width: int, max_action: float ) -> None: super().__init__() self.max_action = max_action self.in_layer = nn.Sequential( nn.Linear(state_dim, hidden_width), nn.ReLU(inplace=True), nn.LayerNorm(hidden_width), ) self.res_layer = nn.Sequential( nn.Linear(hidden_width, hidden_width), nn.ReLU(inplace=True), nn.LayerNorm(hidden_width), nn.Linear(hidden_width, hidden_width), ) self.out_layer = nn.Sequential( nn.Linear(hidden_width, hidden_width), nn.ReLU(inplace=True), nn.LayerNorm(hidden_width), ) self.mean_layer = nn.Sequential(nn.ReLU(), nn.Linear(hidden_width, action_dim)) self.log_std_layer = nn.Sequential( nn.ReLU(inplace=True), nn.Linear(hidden_width, action_dim) ) def forward(self, x: torch.Tensor, deterministic: bool = False) -> tuple: """Forward pass.""" x = self.in_layer(x) x = self.out_layer(x + self.res_layer(x)) mean = self.mean_layer(x) log_std = self.log_std_layer(x) log_std = torch.clamp(log_std, -20, 2) std = torch.exp(log_std) dist = Normal(mean, std) if deterministic: action = mean else: action = dist.rsample() log_pi = dist.log_prob(action).sum(dim=1, keepdim=True) log_pi -= (2 * (np.log(2) - action - F.softplus(-2 * action))).sum( dim=1, keepdim=True ) action = self.max_action * torch.tanh(action) return action, log_pi class Critic(nn.Module): """Critic network.""" def __init__(self, state_dim: int, action_dim: int, hidden_width: int) -> None: super().__init__() self.in_layer1 = nn.Sequential( nn.Linear(state_dim + action_dim, hidden_width), nn.ReLU(inplace=True), nn.LayerNorm(hidden_width), ) self.res_layer1 = nn.Sequential( nn.Linear(hidden_width, hidden_width), nn.ReLU(inplace=True), nn.LayerNorm(hidden_width), nn.Linear(hidden_width, hidden_width), ) self.out_layer1 = nn.Sequential( nn.ReLU(inplace=True), nn.Linear(hidden_width, 1) ) self.in_layer2 = nn.Sequential( nn.Linear(state_dim + action_dim, hidden_width), nn.ReLU(inplace=True), nn.LayerNorm(hidden_width), ) self.res_layer2 = nn.Sequential( nn.Linear(hidden_width, hidden_width), nn.ReLU(inplace=True), nn.LayerNorm(hidden_width), nn.Linear(hidden_width, hidden_width), ) self.out_layer2 = nn.Sequential( nn.ReLU(inplace=True), nn.Linear(hidden_width, 1) ) def forward(self, state: torch.Tensor, action: torch.Tensor) -> tuple: """Forward pass.""" state_action = torch.cat([state, action], 1) q1 = self.in_layer1(state_action) q1 = self.out_layer1(q1 + self.res_layer1(q1)) q2 = self.in_layer2(state_action) q2 = self.out_layer2(q2 + self.res_layer2(q2)) return q1, q2 class SACContinuous: """Soft Actor-Critic for continuous action space.""" def __init__(self, state_dim: int, action_dim: int, max_action: float) -> None: self.gamma = args.gamma self.tau = args.tau self.batch_size = args.batch_size self.learning_rate = args.learning_rate self.hidden_width = args.hidden_width self.max_action = max_action self.target_entropy = -action_dim self.log_alpha = torch.zeros(1, requires_grad=True) self.alpha = self.log_alpha.exp() self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=self.learning_rate) self.actor = Actor(state_dim, action_dim, self.hidden_width, max_action) self.actor_optimizer = torch.optim.Adam( self.actor.parameters(), lr=self.learning_rate ) self.critic = Critic(state_dim, action_dim, self.hidden_width) self.critic_target = copy.deepcopy(self.critic) self.critic_optimizer = torch.optim.Adam( self.critic.parameters(), lr=self.learning_rate ) def choose_action( self, state: np.ndarray, deterministic: bool = False ) -> np.ndarray: """Choose action.""" state = torch.unsqueeze(torch.tensor(state, dtype=torch.float), 0) action, _ = self.actor(state, deterministic) return action.data.numpy().flatten() def learn(self, relay_buffer: ReplayBuffer) -> None: """Learn.""" batch_state, batch_action, batch_reward, batch_next_state, batch_done = ( relay_buffer.sample(self.batch_size) ) with torch.no_grad(): batch_next_action, log_pi_ = self.actor(batch_next_state) target_q1, target_q2 = self.critic_target( batch_next_state, batch_next_action ) target_q = batch_reward + self.gamma * (1 - batch_done) * ( torch.min(target_q1, target_q2) - self.alpha * log_pi_ ) current_q1, current_q2 = self.critic(batch_state, batch_action) critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss( current_q2, target_q ) self.critic_optimizer.zero_grad() critic_loss.backward() self.critic_optimizer.step() for params in self.critic.parameters(): params.requires_grad = False action, log_pi = self.actor(batch_state) q1, q2 = self.critic(batch_state, action) q = torch.min(q1, q2) actor_loss = (self.alpha * log_pi - q).mean() self.actor_optimizer.zero_grad() actor_loss.backward() self.actor_optimizer.step() for params in self.critic.parameters(): params.requires_grad = True alpha_loss = -( self.log_alpha.exp() * (log_pi + self.target_entropy).detach() ).mean() self.alpha_optimizer.zero_grad() alpha_loss.backward() self.alpha_optimizer.step() self.alpha = self.log_alpha.exp() for param, target_param in zip( self.critic.parameters(), self.critic_target.parameters() ): target_param.data.copy_( self.tau * param.data + (1 - self.tau) * target_param.data ) def evaluate_policy(env, agent: SACContinuous) -> float: """Evaluate the policy.""" state = env.reset()[0] done = False episode_reward = 0 action_num = 0 while not done: action = agent.choose_action(state, deterministic=True) next_statue, reward, done, _, _ = env.step(action) episode_reward += reward state = next_statue action_num += 1 if action_num >= 1000: print("action_num too large.") break if episode_reward <= -1000: print("episode_reward too small.") break return episode_reward def training() -> None: """My demo training function.""" env_name = "Pendulum-v1" env = gym.make(env_name) env_evaluate = gym.make(env_name) state_dim = env.observation_space.shape[0] action_dim = env.action_space.shape[0] max_action = float(env.action_space.high[0]) agent = SACContinuous(state_dim, action_dim, max_action) replay_buffer = ReplayBuffer(state_dim, action_dim) evaluate_num = 0 total_steps = 0 while total_steps < args.max_train_steps: state = env.reset()[0] episode_steps = 0 done = False while not done: episode_steps += 1 action = agent.choose_action(state) next_state, reward, done, _, _ = env.step(action) replay_buffer.store(state, action, reward, next_state, done) state = next_state if total_steps >= args.min_buffer_size: agent.learn(replay_buffer) if (total_steps + 1) % args.evaluate_freqency == 0: evaluate_num += 1 evaluate_reward = evaluate_policy(env_evaluate, agent) print( f"evaluate_num: {evaluate_num} \t evaluate_reward: {evaluate_reward}" ) total_steps += 1 if total_steps >= args.max_train_steps: break env.close() torch.save(agent.actor.state_dict(), f"{args.log_path}/trained_model.pth") def testing() -> None: """My demo testing function.""" env_name = "Pendulum-v1" env = gym.make(env_name) state_dim = env.observation_space.shape[0] action_dim = env.action_space.shape[0] max_action = float(env.action_space.high[0]) agent = SACContinuous(state_dim, action_dim, max_action) agent.actor.load_state_dict(torch.load(f"{args.log_path}/trained_model.pth")) state = env.reset()[0] total_rewards = 0 for _ in range(1000): env.render() action = agent.choose_action(state) new_state, reward, _, _, _ = env.step(action) total_rewards += reward state = new_state env.close() print(f"SAC actor scores: {total_rewards}") if __name__ == "__main__": args = parse_args() training() testing()