強化學習實戰 | 表格型Q-Learning玩井子棋(三)優化,優化

埠默笙聲聲聲脈發表於2021-12-10

強化學習實戰 | 表格型Q-Learning玩井字棋(二)開始訓練!中,我們讓agent“簡陋地”訓練了起來,經過了耗費時間的10萬局遊戲過後,卻效果平平,尤其是初始狀態的數值表現和預期相差不小。我想主要原因就是沒有采用等價局面同步更新的方法,導致資料利用率較低。等價局面有7個,分別是:旋轉90°,旋轉180°,旋轉270°,水平翻轉,垂直翻轉,旋轉90°+水平翻轉,旋轉90°+垂直翻轉,如下圖所示。另外,在生成等價局面的同時,也要生成等價的動作,這樣才能實現完整的Q值更新。

步驟1:寫旋轉和翻轉函式

def rotate(array): # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    list_ = list(array)
    list_[:] = map(list,zip(*list_[::-1])) 
    return np.array(list_) # Output: np.array [[7,4,1],[8,5,2],[9,6,3]]


def flip(array_, direction): # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    array = array_.copy()
    n = int(np.floor(len(array)/2))
    if direction == 'vertical': # Output: np.array [[7,8,9],[4,5,6],[1,2,3]]
        for i in range(n):
            temp = array[i].copy()
            array[i] = array[-i-1].copy()
            array[-i-1] = temp
    elif direction == 'horizon': # Output: np.array [[3,2,1],[6,5,4],[9,8,7]]
        for i in range(n):
            temp = array[:,i].copy()
            array[:,i] = array[:,-i-1]
            array[:,-i-1] = temp
    return array

步驟2:寫生成等價局面及等價動作的函式

函式名為 genEqualStateAndAction(state, action),定義在 Agent() 類中。

def genEqualStateAndAction(self, state_, action_): # Input: np.array, tuple(x,y)
        state, action = state_.copy(), action_
        equalStates, equalActions = [], []
        
        # 原局面
        equalStates.append(state)
        equalActions.append(action)
        
        # 水平翻轉
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        state_tf = flip(state_tf, 'horizon')
        action_state_tf = flip(action_state_tf, 'horizon')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 垂直翻轉
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        state_tf = flip(state_tf, 'vertical')
        action_state_tf = flip(action_state_tf, 'vertical')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉90°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉180°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(2):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉270°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(3):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉90° + 水平翻轉
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        state_tf = flip(state_tf, 'horizon')
        action_state_tf = flip(action_state_tf, 'horizon')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉90° + 垂直翻轉
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        state_tf = flip(state_tf, 'vertical')
        action_state_tf = flip(action_state_tf, 'vertical')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
                
        return equalStates, equalActions

細心的讀者可能會發問了:你這生成等價局面不去重的麼?是的,不去重了。原因之一是如果要去重,那麼要比對大量的np.array,實現起來較麻煩,可能會增加很多程式碼時間;原因之二是對重複的局面多次更新,只是不符合邏輯,但應該沒有副作用:畢竟只要資料夠多,最後Q表中的值都會收斂到一個值,而重複出現次數多的局面只是收斂得更快罷了。

步驟3:修改Agent()中的相關程式碼

需要修改方法 addNewState(self, env_, currentMove) 和方法 updateQtable(self, env_, currentMove, done_),整體程式碼如下:

強化學習實戰 | 表格型Q-Learning玩井子棋(三)優化,優化
import gym
import random
import time
import numpy as np

# 檢視所有已註冊的環境
# from gym import envs
# print(envs.registry.all()) 

def str2tuple(string): # Input: '(1,1)'
    string2list = list(string)
    return ( int(string2list[1]), int(string2list[4]) ) # Output: (1,1)


def rotate(array): # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    list_ = list(array)
    list_[:] = map(list,zip(*list_[::-1])) 
    return np.array(list_) # Output: np.array [[7,4,1],[8,5,2],[9,6,3]]


def flip(array_, direction): # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    array = array_.copy()
    n = int(np.floor(len(array)/2))
    if direction == 'vertical': # Output: np.array [[7,8,9],[4,5,6],[1,2,3]]
        for i in range(n):
            temp = array[i].copy()
            array[i] = array[-i-1].copy()
            array[-i-1] = temp
    elif direction == 'horizon': # Output: np.array [[3,2,1],[6,5,4],[9,8,7]]
        for i in range(n):
            temp = array[:,i].copy()
            array[:,i] = array[:,-i-1]
            array[:,-i-1] = temp
    return array


class Game():
    def __init__(self, env):
        self.INTERVAL = 0 # 行動間隔
        self.RENDER = False # 是否顯示遊戲過程
        self.first = 'blue' if random.random() > 0.5 else 'red' # 隨機先後手
        self.currentMove = self.first
        self.env = env
        self.agent = Agent()
    
    
    def switchMove(self): # 切換行動玩家
        move = self.currentMove
        if move == 'blue': self.currentMove = 'red'
        elif move == 'red': self.currentMove = 'blue'
    
    
    def newGame(self): # 新建遊戲
        self.first = 'blue' if random.random() > 0.5 else 'red'
        self.currentMove = self.first
        self.env.reset()
        self.agent.reset()
    
    
    def run(self): # 玩一局遊戲
        self.env.reset() # 在第一次step前要先重置環境,不然會報錯
        while True:
            print(f'--currentMove: {self.currentMove}--')
            self.agent.updateQtable(self.env, self.currentMove, False)
            
            if self.currentMove == 'blue':
                self.agent.lastState_blue = self.env.state.copy()
            elif self.currentMove == 'red':
                self.agent.lastState_red = self.agent.overTurn(self.env.state) # 紅方視角需將狀態翻轉
                
            action = self.agent.epsilon_greedy(self.env, self.currentMove)
            if self.currentMove == 'blue':
                self.agent.lastAction_blue = action['pos']
            elif self.currentMove == 'red':
                self.agent.lastAction_red = action['pos']
            
            state, reward, done, info = self.env.step(action)
            if done:
                self.agent.lastReward_blue = reward
                self.agent.lastReward_red = -1 * reward
                self.agent.updateQtable(self.env, self.currentMove, True)
            else:     
                if self.currentMove == 'blue':
                    self.agent.lastReward_blue = reward
                elif self.currentMove == 'red':
                    self.agent.lastReward_red = -1 * reward
            
            if self.RENDER: self.env.render()
            self.switchMove()
            time.sleep(self.INTERVAL)
            if done:
                self.newGame()
                if self.RENDER: self.env.render()
                time.sleep(self.INTERVAL)
                break
                    
class Agent():
    def __init__(self):
        self.Q_table = {}
        self.EPSILON = 0.05
        self.ALPHA = 0.5
        self.GAMMA = 1 # 折扣因子
        self.lastState_blue = None
        self.lastAction_blue = None
        self.lastReward_blue = None
        self.lastState_red = None
        self.lastAction_red = None
        self.lastReward_red = None
    
    
    def reset(self):
        self.lastState_blue = None
        self.lastAction_blue = None
        self.lastReward_blue = None
        self.lastState_red = None
        self.lastAction_red = None
        self.lastReward_red = None
    
    
    def getEmptyPos(self, state): # 返回空位的座標
        action_space = []
        for i, row in enumerate(state):
            for j, one in enumerate(row):
                if one == 0: action_space.append((i,j)) 
        return action_space
    
    
    def randomAction(self, env_, mark): # 隨機選擇空格動作
        actions = self.getEmptyPos(env_)
        action_pos = random.choice(actions)
        action = {'mark':mark, 'pos':action_pos}
        return action
    
    
    def overTurn(self, state): # 翻轉狀態
        state_ = state.copy()
        for i, row in enumerate(state_):
            for j, one in enumerate(row):
                if one != 0: state_[i][j] *= -1
        return state_
    
    
    def genEqualStateAndAction(self, state_, action_): # Input: np.array, tuple(x,y)
        state, action = state_.copy(), action_
        equalStates, equalActions = [], []
        
        # 原局面
        equalStates.append(state)
        equalActions.append(action)
        
        # 水平翻轉
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        state_tf = flip(state_tf, 'horizon')
        action_state_tf = flip(action_state_tf, 'horizon')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 垂直翻轉
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        state_tf = flip(state_tf, 'vertical')
        action_state_tf = flip(action_state_tf, 'vertical')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉90°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉180°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(2):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉270°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(3):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉90° + 水平翻轉
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        state_tf = flip(state_tf, 'horizon')
        action_state_tf = flip(action_state_tf, 'horizon')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        
        # 旋轉90° + 垂直翻轉
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        state_tf = flip(state_tf, 'vertical')
        action_state_tf = flip(action_state_tf, 'vertical')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
                
        return equalStates, equalActions
    
    
    def addNewState(self, env_, currentMove): # 若當前狀態不在Q表中,則新增狀態
         state = env_.state if currentMove == 'blue' else self.overTurn(env_.state) # 如果是紅方行動則翻轉狀態
         eqStates, eqActions = self.genEqualStateAndAction(state, (0,0))
         
         for one in eqStates:
             if str(one) not in self.Q_table:
                 self.Q_table[str(one)] = {}
                 actions = self.getEmptyPos(one)
                 for action in actions:
                     self.Q_table[str(one)][str(action)] = 0
    
        
    def epsilon_greedy(self, env_, currentMove): # ε-貪心策略
        state = env_.state if currentMove == 'blue' else self.overTurn(env_.state) # 如果是紅方行動則翻轉狀態
        Q_Sa = self.Q_table[str(state)]
        maxAction, maxValue, otherAction = [], -100, [] 
        for one in Q_Sa:
            if Q_Sa[one] > maxValue:
                maxValue = Q_Sa[one]
        for one in Q_Sa:
            if Q_Sa[one] == maxValue:
                maxAction.append(str2tuple(one))
            else:
                otherAction.append(str2tuple(one))
        
        try:
            action_pos = random.choice(maxAction) if random.random() > self.EPSILON else random.choice(otherAction)
        except: # 處理從空的otherAction中取值的情況
            action_pos = random.choice(maxAction) 
        action = {'mark':currentMove, 'pos':action_pos}
        return action
    
    
    def updateQtable(self, env_, currentMove, done_):
        
        judge = (currentMove == 'blue' and self.lastState_blue is None) or \
                (currentMove == 'red' and self.lastState_red is None)
        if judge: # 邊界情況1:若agent無上一狀態,說明是遊戲中首次動作,那麼只需要新增狀態就好,無需更新Q值
            self.addNewState(env_, currentMove)
            return
                
        if done_: # 邊界情況2:若當前狀態S_是終止狀態,則無需把S_新增至Q表格中,直接令maxQ_S_a = 0,並同時更新雙方Q值
            for one in ['blue', 'red']:
                S = self.lastState_blue  if one == 'blue' else self.lastState_red
                a = self.lastAction_blue if one == 'blue' else self.lastAction_red
                eqStates, eqActions = self.genEqualStateAndAction(S, a)
                R = self.lastReward_blue if one == 'blue' else self.lastReward_red
                # print('lastState S:\n', S)
                # print('lastAction a: ', a)
                # print('lastReward R: ', R)
                # print('\n')
                maxQ_S_a = 0
                for S, a in zip(eqStates, eqActions):
                    self.Q_table[str(S)][str(a)] = (1 - self.ALPHA) * self.Q_table[str(S)][str(a)] \
                                                    + self.ALPHA * (R + self.GAMMA * maxQ_S_a)
            return
          
        # 其他情況下:Q表無當前狀態則新增狀態,否則直接更新Q值
        self.addNewState(env_, currentMove)
        S_ = env_.state if currentMove == 'blue' else self.overTurn(env_.state)
        S = self.lastState_blue  if currentMove == 'blue' else self.lastState_red
        a = self.lastAction_blue if currentMove == 'blue' else self.lastAction_red
        eqStates, eqActions = self.genEqualStateAndAction(S, a)
        R = self.lastReward_blue if currentMove == 'blue' else self.lastReward_red
        # print('lastState S:\n', S)
        # print('State S_:\n', S_)
        # print('lastAction a: ', a)
        # print('lastReward R: ', R)
        # print('\n')
        Q_S_a = self.Q_table[str(S_)]
        maxQ_S_a = -100 
        for one in Q_S_a:
            if Q_S_a[one] > maxQ_S_a:
                maxQ_S_a = Q_S_a[one]
        for S, a in zip(eqStates, eqActions): 
            self.Q_table[str(S)][str(a)] = (1 - self.ALPHA) * self.Q_table[str(S)][str(a)] \
                                            + self.ALPHA * (R + self.GAMMA * maxQ_S_a)
                                            
                                            
env = gym.make('TicTacToeEnv-v0')
game = Game(env)
for i in range(10000):
    print('episode', i)
    game.run()
Q_table = game.agent.Q_table
View Code

測試

經過了上述優化,agent能夠在一輪對局中更新16個Q值,比起上一節 強化學習實戰 | 表格型Q-Learning玩井字棋(二)開始訓練! 中的更新2個Q值要多8倍,不妨就玩1萬局遊戲,看看是否能玩出之前玩8萬局遊戲的效果。

專案1:檢視Q表格的狀態數

 一般般,仍然有狀態沒有覆蓋到。

專案2:檢視初始狀態

先手開局:

這效果也太好了吧!不但有完美的對稱,還有涇渭分明的勝負判斷: 第一步走四邊就穩了,走四角和走中間都是輸面大。看來優化之後,Q值的整體方差這一塊表現得非常好了。

再貼一個後手開局的情況:

專案3:測試程式碼時間

引入了更復雜的trick,確實是完美地爭取到了一些收益,但玩一局遊戲的時間一定是增加了,增加了多少呢?我們用上一節的老演算法和本節的演算法分別跑2000局遊戲,記錄一下時間(本人使用的CPU是:Intel(R) Core(TM) i7-9750H)。

雙向更新+等價局面同步更新:

 雙向更新:

增加了不到兩倍的時間,換來了大約8倍的更新量提高,還降低了方差,看來這優化是賺的。

小結

拿著優化好的演算法,心裡也有了些底氣,可以放心大膽地增加訓練時間了。下一節,我們將用訓練完全Q表,用pygame做一個擁有人機對陣,機機對戰,作弊功能的井字棋遊戲。還可以做一些對戰的資料分析,比如AI內戰的勝率多高?AI對陣隨機策略的勝率多高?下節見!

 

 

 

 

相關文章