強化學習實戰 | 自定義Gym環境之掃雷

埠默笙聲聲聲脈發表於2022-01-26

開始之前

先考慮幾個問題:

  • Q1:如何展開無雷區?
  • Q2:如何計算格子的提示數?
  • Q3:如何表示掃雷遊戲的狀態?

A1:可以使用遞迴函式,或是堆疊。

A2:一般的做法是,需要開啟某格子時,再去統計周圍的雷數。如果有方便的二維卷積函式可以呼叫,這會是個更簡潔的方法:

$$\begin{bmatrix}
1 & 0 & 0 & 1 & 0\\
0 & 1 & 0 & 0 & 1\\
1 & 0 & 1 & 0 & 0\\
0 & 0 & 0 & 0 & 0\\
0 & 1 & 0 & 0 & 1
\end{bmatrix}\bigstar
\begin{bmatrix}
1 & 1 & 1\\
1 & 0 & 1\\
1 & 1 & 1
\end{bmatrix}=
\begin{bmatrix}
1 & 2 & 2 & 1 & 2\\
3 & 3 & 3 & 3 & 1\\
1 & 3 & 1 & 2 & 1\\
2 & 3 & 2 & 2 & 1\\
1 & 0 & 1 & 1 & 0
\end{bmatrix}$$

 不妨用 $\bigstar$ 表示二維卷積運算。等號左邊的5×5矩陣表示了雷的分佈情況,值1表示有雷,值0表示無雷;等號左邊的3×3矩陣是求解周圍雷數的卷積核(或稱濾波器,特徵提取器);等號右邊的矩陣即是所有格子的周圍雷數。

程式碼實現起來也非常簡單:

from scipy import signal
import numpy as np
state_mine = np.array([[1,0,0,1,0],[0,1,0,0,1],[1,0,1,0,0],[0,0,0,0,0],[0,1,0,0,1]])
KERNAL = np.array([[1,1,1],[1,0,1],[1,1,1]])
state_num = signal.convolve2d(state_mine, KERNAL, 'same')

A3:對於玩家來說,遊戲狀態是不完全觀測的,也即需要區分觀測狀態環境狀態。環境狀態包括雷分佈矩陣,和提示數矩陣(也即上式提到的);觀測狀態是玩家部分可見的環境狀態,需要根據格子的開啟狀態對雷分佈矩陣進行部分遮蔽。觀測狀態不包括雷分佈矩陣,因為一旦觸雷即遊戲結束,所以遊戲中所有非終止狀態都是無雷的。

那麼對於一個大小為$M \times N$的掃雷遊戲,環境狀態可以表示為 $M \times N \times 2$ 的張量:頻道1是雷分佈矩陣,頻道2是提示數矩陣;觀測狀態可以表示為 $M \times N \times 2$ 的張量:頻道1是表示格子開啟狀態的矩陣(值1為開啟,值0為未開啟),並以此矩陣對 提示數矩陣 進行元素乘,完成對環境狀態的部分遮蔽,作為第二個頻道。對於numpy.array而言,元素乘是容易的:

observe_num = state_num * state_open

以下圖的遊戲狀態為例說明:

環境狀態為:

$$\begin{bmatrix}
 &  &  &  & \\
 & 1 &  &  & \\
 &  &  &  & \\
 &  &  & 1 & \\
1 & 1 &  &  & 
\end{bmatrix}\times
\begin{bmatrix}
1 & 1 & 1 & 0 & 0\\
1 & 0 & 1 & 0 & 0\\
1 & 1 & 2 & 1 & 1\\
2 & 2 & 2 & 0 & 1\\
1 & 1 & 2 & 1 & 1
\end{bmatrix}$$

觀測狀態為:

$$\begin{bmatrix}
1 & 0 & 1 & 0 & 0\\
1 & 0 & 1 & 0 & 0\\
1 & 0 & 2 & 1 & 1\\
2 & 2 & 0 & 0 & 1\\
1 & 0 & 0 & 1 & 0
\end{bmatrix}\times
\begin{bmatrix}
1 &  & 1 & 1 & 1\\
1 &  & 1 & 1 & 1\\
1 &  & 1 & 1 & 1\\
1 & 1 &  &  & 1\\
1 &  &  & 1 & 
\end{bmatrix}$$

 但這種表示方式不是唯一的,比如我們可以把提示數矩陣拆成9個頻道,分別表示0~8的提示數。那麼觀測狀態就變成了 $M \times N \times 10$ 的張量:

$$\begin{bmatrix}
& & & 1 & 1\\
& & & 1 & 1\\
& & & & \\
& & & & \\
& & & &
\end{bmatrix}\times
\begin{bmatrix}
1 & & 1 & & \\
1 & & 1 & & \\
1 & & & 1 & 1\\
& & & & 1\\
& & & 1 &
\end{bmatrix}\times
\begin{bmatrix}
& & & & \\
& & & & \\
& & 1 & & \\
1 & 1 & & & \\
& & & &
\end{bmatrix}\times
\begin{bmatrix}
& & & & \\
& & & & \\
& & & & \\
& & & & \\
& & & &
\end{bmatrix}\times
\cdots \times
\begin{bmatrix}
& & & & \\
& & & & \\
& & & & \\
& & & & \\
& & & &
\end{bmatrix}\times
\begin{bmatrix}
1 & & 1 & 1 & 1\\
1 & & 1 & 1 & 1\\
1 & & 1 & 1 & 1\\
1 & 1 & & & 1\\
1 & & & 1 &
\end{bmatrix}$$

狀態空間的設計是靈活的,唯一的評價的標準是完整的學習系統的效能表現。如果採用以上多頻道式的狀態空間設計,那麼後續可以很方便地使用卷積神經網路開展學習任務。你也可以把張量陣展成一維的向量,然後用全連線神經網路處理。本文後續的實現將採用 $M \times N \times 2$ 的狀態空間表達。

步驟1:新建檔案

為了執行pytorch,我使用anaconda的環境管理操作建立了名為pytorch1.1的環境名,並在這個環境下安裝了openAI gym,因此我來到目錄:D:\Anaconda\envs\pytorch1.1\Lib\site-packages\gym\envs\user 下,新建檔案 __init__.pyMineSweeper_env.py。

步驟2:編寫檔案 MineSweeper_env.py

一個標準的gym env類包含三個方法:reset(),step(action),和render()。

  • reset() 用於初始化環境;
  • step(action) 有四個返回值:state,reward,done,和info,因此我們需要在該函式中完成掃雷遊戲的全部邏輯;
  • render() 用於視覺化環境。我在網上沒有找到gym的原生方法rendering可以顯示文字的說法(如果有知曉的朋友請留言,感謝!),所以是通過pyglet + 動態變數名的方式實現大量字元的顯示,具體做法可見 強化學習實戰 | 自定義Gym環境之顯示字串

MineSweeper_env.py 的整體程式碼如下:

import gym
import random
import time
import numpy as np
from scipy import signal # 二維卷積
import pyglet # 顯示文字
from gym.envs.classic_control import rendering


class DrawText: # 用於在rendering中顯示文字
    def __init__(self, label:pyglet.text.Label):
        self.label=label
    def render(self):
        self.label.draw()


class MineSweeperEnv(gym.Env):
    def __init__(self):
        self.MINE_NUM = 20
        self.ROW, self.COL = 12, 12
        self.SIZE = 40
        WIDTH = self.COL * self.SIZE
        HEIGHT = self.ROW * self.SIZE
        self.viewer = rendering.Viewer(WIDTH, HEIGHT)
        self.state_mine = None
        self.state_num = None
        self.state_open = None
        self.gameOver = False
        
        
    def reset(self):
        # 初始化:佈雷狀態
        MINE_NUM = self.MINE_NUM
        self.state_mine = np.zeros(self.ROW * self.COL) 
        self.state_mine[:MINE_NUM] = 1
        random.shuffle(self.state_mine)
        self.state_mine = self.state_mine.reshape(self.ROW, self.COL)
        # 初始化:提示數字
        KERNAL = np.array([[1,1,1], [1,0,1], [1,1,1]])
        self.state_num = signal.convolve2d(self.state_mine, KERNAL, 'same')
        # 初始化:開啟狀態
        self.state_open = np.zeros((self.ROW, self.COL))
        # 初始化:遊戲是否結束
        self.gameOver = False
        
    
    def getRoundSet(self, x, y):
        roundSet = []
        for i in range(x-1, x+2):
            for j in range(y-1, y+2):
                if 0 <= i < self.ROW and 0 <= j < self.COL and (i, j) != (x, y):
                    roundSet.append((i, j))
        return roundSet
    
    
    def step(self, action):
        # 執行動作
        x, y = action
        # 若開啟數字不為0
        if self.state_num[x, y] >= 1:
            self.state_open[x, y] = 1
        # 若開啟數字為0 則展開無雷區
        if self.state_num[x, y] == 0:
            stack = []
            stack.append((x, y))
            while len(stack):
                row, col = stack.pop()
                self.state_open[row, col] = 1
                for one in self.getRoundSet(row, col):
                    # 排除已經開啟的格子
                    if self.state_open[one] == 1:
                        continue
                    if self.state_num[one] >= 1:
                        self.state_open[one] = 1
                    else:
                        stack.append(one)         
    
        # 是否獲勝或失敗/獲得獎勵
        done, reward = False, 0
        # 若開啟雷 則遊戲失敗
        if self.state_mine[x, y] == 1:
            self.state_open[x, y] = 1
            self.gameOver = True
            done, reward = True, -1
        # 若剩餘未開啟的格子數 = 雷數 則獲勝
        if ROW*COL - self.state_open.sum() == self.MINE_NUM:
            self.gameOver = True
            done, reward = True, 1
        
        # 報告(維持gym step的標準格式)
        info = {}
        # 觀測狀態
        observe_num = self.state_num * self.state_open
        observe = [observe_num, self.state_open]
        return observe, reward, done, info
    
    
    def render(self, mode='human'):
        ROW, COL, SIZE = self.ROW, self.COL, self.SIZE
        # 畫方塊
        for i in range(ROW):
            for j in range(COL):
                X, Y = j*SIZE, (ROW-i-1)*SIZE
                tile = rendering.make_polygon([(X,Y), (X+SIZE,Y), (X+SIZE,Y+SIZE), (X,Y+SIZE)], filled=True)
                if self.state_open[i,j] == 0:
                    tile.set_color(106/255,116/255,166/255)
                if self.state_open[i,j] == 1 and self.state_mine[i,j] == 0:
                    tile.set_color(255/255,242/255,204/255)
                if self.state_open[i,j] == 1 and self.state_mine[i,j] == 1:
                    tile.set_color(220/255,20/255,60/255)
                self.viewer.add_geom(tile)
        # 畫分隔線
        WIDTH = COL*SIZE
        HEIGHT = ROW*SIZE
        for i in range(ROW+1):
            line = rendering.Line((0, i*SIZE), (WIDTH, i*SIZE))
            line.set_color(80/255, 80/255, 80/255)
            self.viewer.add_geom(line)
        for j in range(COL+1):
            line = rendering.Line((j*SIZE, 0), (j*SIZE, HEIGHT))
            line.set_color(80/255, 80/255, 80/255)
            self.viewer.add_geom(line)
        # 畫數字
        for i in range(ROW):
            for j in range(COL):
                exec('label_{}_{} = {}'.format(i, j, None))
                names = locals()
                NUM = int(self.state_num[i,j])
                COLOR = (255, 255, 255, 255)
                if NUM == 1:
                    COLOR = (46, 117, 182, 255)
                elif NUM == 2:
                    COLOR = (84, 130, 53, 255)
                elif NUM == 3:
                    COLOR = (192, 0, 0, 255)
                elif NUM == 4:
                    COLOR = (112, 48, 160, 255)
                elif NUM == 5:
                    COLOR = (132, 60, 12, 255)
                elif NUM == 6:
                    COLOR = (191, 144, 0, 255)
                elif NUM == 7:
                    COLOR = (32, 56, 100, 255)
                elif NUM == 8:
                    COLOR = (13, 13, 13, 255)
                names['label_' + str(i) + '_' + str(j)] = pyglet.text.Label('{}'.format(NUM), font_size=15,
                                  x=(j+0.32)*SIZE, y=(ROW-i-1+0.23)*SIZE, anchor_x='left', anchor_y='bottom',
                                  color=COLOR)
                label = names['label_{}_{}'.format(i, j)]
                label.draw()
                if self.state_mine[i,j] == 0 and self.state_open[i,j] == 1 and self.state_num[i,j] >= 1:
                    self.viewer.add_geom(DrawText(label))
                # 畫雷
                if self.gameOver == True:
                    if self.state_mine[i,j] == 1:
                        mine = rendering.make_circle(10, 6, filled=True)
                        mine.set_color(30/255, 30/255, 30/255)
                        translation = rendering.Transform(translation=((j+0.5)*SIZE, (ROW-i-1+0.5)*SIZE))
                        mine.add_attr(translation)
                        self.viewer.add_geom(mine)
                
        return self.viewer.render(return_rgb_array=mode == 'rgb_array')
        

# 測試程式碼:以隨機策略執行動作
if __name__ == '__main__': 
    MineSweeper = MineSweeperEnv()
    ROW, COL = MineSweeper.ROW, MineSweeper.COL
    MineSweeper.reset()
    MineSweeper.render()
    while MineSweeper.gameOver is not True:
        while True:
            rand = random.choice(range(ROW*COL))
            x, y = rand//ROW, rand%ROW
            if MineSweeper.state_open[x, y] == 0:
                action = (x, y)
                break
        state, reward, done, info = MineSweeper.step(action)
        MineSweeper.render()
        time.sleep(0.5)

直接執行檔案,執行測試程式碼(以隨機策略執行動作):

步驟3:編寫 __init__.py

在 __init__.py 中引入類的資訊,新增:

from gym.envs.user.MineSweeper_env import MineSweeperEnv

步驟4:註冊環境

來到目錄:D:\Anaconda\envs\pytorch1.1\Lib\site-packages\gym,開啟 __init__.py,新增程式碼:

register(
    id="MineSweeperEnv-v0",
    entry_point="gym.envs.user:MineSweeperEnv",
    max_episode_steps=200,    
)

步驟5:測試環境

在相同的conda環境下,輸入程式碼:

import gym
env = gym.make('MineSweeperEnv-v0')
env.reset()
env.render()

若無報錯,則說明gym環境註冊成功。

 

相關文章