In the CartPole game, a cart balances a freely swinging pole on top of it: whenever the pole starts tipping toward one side, the cart moves to bring it back to a dynamically upright position. The policy function is a simple two-layer neural network. The input is the 4-dimensional state (cart position, cart velocity, pole angle, pole angular velocity) and the output action is "move left" or "move right". Experimenting with the number of inputs: with at least 3 state components the pole stays up for a while, with only 2 the policy never learns anything useful, and with all 4 it learns a very stable policy.
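To see exactly what those four state components are, here is a minimal sketch for inspecting the environment (assuming the gym >= 0.26 API used in the code below, where reset() returns an (obs, info) tuple):

import gym

# Inspect the CartPole-v1 spaces before wiring up the policy network.
env = gym.make('CartPole-v1')
obs, _ = env.reset()
print(env.observation_space)  # 4-dimensional Box: cart position, cart velocity, pole angle, pole angular velocity
print(env.action_space)       # Discrete(2): 0 = push left, 1 = push right
print(obs)                     # a sample initial state: four small values near zero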
The policy-gradient implementation below uses torch to build a simple neural network.
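For reference, the loss accumulated by train() in the code below is the REINFORCE estimator with a baseline: each step's discounted return, minus the mean return of the last ten episodes, weights the log-probability of the action that was taken:

$$L(\theta) = -\sum_{t} \log \pi_\theta(a_t \mid s_t)\,(G_t - b), \qquad G_t = \sum_{k \ge t} \gamma^{\,k-t} r_k, \quad \gamma = 0.99$$

where the baseline b is the mean discounted return of the last ten episodes (0 while there is no history yet).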
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import sys
from collections import deque
import numpy as np

# Policy network definition
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),    # 4 state inputs -> 10 hidden units
            nn.Tanh(),
            nn.Linear(10, 2),    # probabilities for the 2 actions
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        # state layout: cart position, cart velocity, pole angle, pole angular velocity
        selected_values = x[:, [0, 1, 2, 3]]  # feed all four components; change the indices to try fewer inputs
        return self.fc(selected_values)

# Training step: one REINFORCE update over the collected trajectories
def train(policy_net, optimizer, trajectories):
    policy_net.zero_grad()
    loss = 0
    print(trajectories[0])  # debug: inspect the first transition
    for trajectory in trajectories:
        # if trajectory["returns"] > 90:
        #     returns = torch.tensor(trajectory["returns"]).float()
        # else:
        returns = torch.tensor(trajectory["returns"]).float() - torch.tensor(trajectory["step_mean_reward"]).float()
        # print(f"return {returns}")
        log_probs = trajectory["log_prob"]
        loss += -(log_probs * returns).sum()  # accumulate the policy-gradient loss
    loss.backward()
    optimizer.step()
    return loss.item()

# Main training loop
def main():
    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    optimizer = optim.Adam(policy_net.parameters(), lr=0.01)
    print(env.action_space)
    print(env.observation_space)
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()
    rewards_one_episode = []
    for episode in range(10000):
        state = env.reset()
        done = False
        trajectories = []
        state = state[0]
        step = 0
        torch.save(policy_net, 'policy_net_full.pth')  # save the current policy so play() can load it
        while not done:
            state_tensor = torch.tensor(state).float().unsqueeze(0)
            probs = policy_net(state_tensor)
            action = torch.distributions.Categorical(probs).sample().item()
            log_prob = torch.log(probs.squeeze(0)[action])
            next_state, reward, done, _, _ = env.step(action)
            # print(episode)
            trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
            state = next_state
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
            step += 1
            # Render the environment state (only once recent episodes get long enough)
            if rewards_one_episode and rewards_one_episode[-1] > 99:
                screen.fill((255, 255, 255))
                cart_x = int(state[0] * 100 + 300)
                pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
                # print(state)
                pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300),
                                 (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))),
                                  300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
                pygame.display.flip()
                clock.tick(200)
        print(f"Episode {episode}", f"fell after {step} steps")
        # Compute discounted returns for the policy gradient
        returns = 0
        for traj in reversed(trajectories):
            returns = traj["reward"] + 0.99 * returns
            traj["returns"] = returns
            if rewards_one_episode:
                # print(rewards_one_episode[:10])
                traj["step_mean_reward"] = np.mean(rewards_one_episode[-10:])  # baseline: mean return of the last 10 episodes
            else:
                traj["step_mean_reward"] = 0
        rewards_one_episode.append(returns)
        # print(rewards_one_episode[:10])
        train(policy_net, optimizer, trajectories)

# Run the saved policy with rendering
def play():
    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()
    state = env.reset()
    done = False
    trajectories = deque()
    state = state[0]
    step = 0
    policy_net = torch.load('policy_net_full.pth')  # load the full model saved during training
    while not done:
        state_tensor = torch.tensor(state).float().unsqueeze(0)
        probs = policy_net(state_tensor)
        action = torch.distributions.Categorical(probs).sample().item()
        log_prob = torch.log(probs.squeeze(0)[action])
        next_state, reward, done, _, _ = env.step(action)
        trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
        state = next_state
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
        # Render the environment state
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        # print(state)
        pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300),
                         (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))),
                          300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
        pygame.display.flip()
        clock.tick(60)
        step += 1
    print(f"Fell after {step} steps")

if __name__ == '__main__':
    main()    # training
    # play()  # inference
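As a quick sanity check of the return computation in main(), here is the same backward recursion on a toy reward sequence (a standalone sketch, not part of the training script; the helper name discounted_returns is only for illustration):

# Standalone sketch of the return recursion used in main(): G_t = r_t + gamma * G_{t+1},
# computed by walking the episode backwards.
def discounted_returns(rewards, gamma=0.99):
    returns = []
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return list(reversed(returns))

print(discounted_returns([1.0, 1.0, 1.0]))  # approximately [2.9701, 1.99, 1.0]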
Results: training is not very stable. Sometimes the policy fails to learn even after many episodes; other times it learns within a few dozen episodes.
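One way to pin down this run-to-run variance is to fix every randomness source before training, so that a lucky or unlucky run can be reproduced and compared across seeds. A minimal sketch (not part of the script above; assumes gym >= 0.26, and the seed value is arbitrary):

import gym
import numpy as np
import torch

seed = 0
torch.manual_seed(seed)   # seeds the weight init and the Categorical action sampling
np.random.seed(seed)      # numpy is only used for the baseline; seeded for completeness

env = gym.make('CartPole-v1')
state, _ = env.reset(seed=seed)   # seeds the environment's initial-state RNG
env.action_space.seed(seed)       # seeds action_space.sample(), if it is ever used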