In practice the KL divergence term contributes little in this game. The action space is simple, unlike an LM where the action is a very large vector, so you can simply maximize surr1 on its own, and experiments confirm this. Also, the KL coefficient must not be set too large, otherwise the penalty is too strong: the action distributions produced by the action model and the ref model do not actually differ much here.
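As a minimal sketch of that comparison (the function name and the beta value below are illustrative assumptions, not taken from the code that follows): PPO1 maximizes surr1 minus a small KL penalty, while PPO2 drops the penalty and clips the ratio instead.

import torch

def actor_loss(new_log_probs, old_log_probs, advantages,
               beta=0.01, clip_param=0.2, use_clip=False):
    # Probability ratio pi_new(a|s) / pi_old(a|s)
    ratio = (new_log_probs - old_log_probs).exp()
    surr1 = ratio * advantages  # unclipped surrogate objective
    if not use_clip:
        # PPO1-style: surr1 with a small KL penalty; a large beta
        # over-penalizes a policy that barely differs from the reference
        kl = (new_log_probs.exp() * (new_log_probs - old_log_probs)).mean()
        return -(surr1.mean() - beta * kl)
    # PPO2-style: clip the ratio instead of penalizing KL
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    return -torch.min(surr1, surr2).mean()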
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pygame
import sys
from collections import deque

# Policy network
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 2),
            nn.Tanh(),
            nn.Linear(2, 2),  # CartPole's action space has 2 actions
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.fc(x)

# Value network
class ValueNetwork(nn.Module):
    def __init__(self):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 2),
            nn.Tanh(),
            nn.Linear(2, 1)
        )

    def forward(self, x):
        return self.fc(x)

# Rollout buffer
class RolloutBuffer:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []

    def store(self, state, action, reward, done, log_prob):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)
        self.log_probs.append(log_prob)

    def clear(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []

    def get_batch(self):
        return (
            torch.tensor(np.array(self.states), dtype=torch.float),
            torch.tensor(self.actions, dtype=torch.long),
            torch.tensor(self.rewards, dtype=torch.float),
            torch.tensor(self.dones, dtype=torch.bool),
            torch.tensor(self.log_probs, dtype=torch.float)
        )

# PPO update
def ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer,
               epochs=100, gamma=0.99, clip_param=0.2):
    states, actions, rewards, dones, old_log_probs = buffer.get_batch()
    returns = []
    advantages = []
    G = 0
    adv = 0
    dones = dones.to(torch.int)
    # print(dones)
    values = value_net(states).squeeze(-1).detach()  # V(s_t); no gradient needed here
    next_value = 0.0
    for reward, done, value in zip(reversed(rewards), reversed(dones), reversed(values)):
        if done:
            G = 0
            adv = 0
            next_value = 0.0
        G = reward + gamma * G  # Monte Carlo return, accumulated backwards
        delta = reward + gamma * next_value * (1 - done) - value.item()  # TD error
        adv = delta + gamma * 0.95 * adv * (1 - done)  # GAE advantage (lambda = 0.95)
        # adv = delta + adv * (1 - done)
        next_value = value.item()
        returns.insert(0, G)
        advantages.insert(0, adv)
    returns = torch.tensor(returns, dtype=torch.float)  # targets for the value net
    advantages = torch.tensor(advantages, dtype=torch.float)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # normalize (baseline)

    for _ in range(epochs):
        action_probs = policy_net(states)
        dist = torch.distributions.Categorical(action_probs)
        new_log_probs = dist.log_prob(actions)
        ratio = (new_log_probs - old_log_probs).exp()
        KL = (new_log_probs.exp() * (new_log_probs - old_log_probs)).mean()  # KL divergence: p * log(p / p')
        # The next few lines are the core of PPO
        surr1 = ratio * advantages
        PPO1, PPO2 = True, False  # choose the KL-penalty objective (PPO1) or the clipped objective (PPO2)
        # print(surr1, KL * 500)
        if PPO1:
            actor_loss = -(surr1 - KL).mean()
        if PPO2:
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

        optimizer_policy.zero_grad()
        actor_loss.backward()
        optimizer_policy.step()

        value_loss = (returns - value_net(states).squeeze(-1)).pow(2).mean()
        optimizer_value.zero_grad()
        value_loss.backward()
        optimizer_value.step()

# Initialize the environment and models
env = gym.make('CartPole-v1')
policy_net = PolicyNetwork()
value_net = ValueNetwork()
optimizer_policy = optim.Adam(policy_net.parameters(), lr=3e-4)
optimizer_value = optim.Adam(value_net.parameters(), lr=1e-3)
buffer = RolloutBuffer()

# Pygame initialization
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()
draw_on = False

# Training loop
state = env.reset()
for episode in range(10000):  # number of training episodes
    done = False
    state = state[0]  # gym >= 0.26: reset() returns (obs, info)
    step = 0
    while not done:
        step += 1
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        action_probs = policy_net(state_tensor)  # inference with the old policy
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        # done only tracks termination; truncation is ignored on purpose so an
        # episode can run past the 500-step limit (see the step > 2000 check below)
        next_state, reward, done, _, _ = env.step(action.item())
        buffer.store(state, action.item(), reward, done, log_prob.item())
        state = next_state

        # Real-time display
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        if draw_on:
            # Clear the screen and redraw
            screen.fill((0, 0, 0))
            cart_x = int(state[0] * 100 + 300)  # convert cart position to screen coordinates
            pygame.draw.rect(screen, (0, 128, 255), (cart_x, 300, 50, 30))
            pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300),
                             (cart_x + 25 - int(50 * np.sin(state[2])),
                              300 - int(50 * np.cos(state[2]))), 5)
            pygame.display.flip()
            clock.tick(60)
        if step > 2000:
            draw_on = True

    ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer)
    buffer.clear()
    state = env.reset()
    print(f'Episode {episode} completed, reward: {step}.')

# End of training
env.close()
pygame.quit()
Result: