使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車

機器之心發表於2018-03-13

對機器學習感興趣的人都知道基於人工智慧的強化學習的能力。過去的紀念見證了很多使用強化學習(RL)做出的突破。DeepMind 將強化學習與機器學習相結合,在很多 Atari 遊戲中達到了超越人類的結果,並且在 2016 年 3 月的時候以 4:1 的成績擊敗了圍棋冠軍李世石。儘管強化學習目前在很多遊戲環境中超越了人類,但是用它來解決一些需要最優決策和效率的問題還是比較新穎的,而且強化學習也會在未來的機器智慧方面起到重要的作用。

簡單地解釋,強化學習就是智慧體通過採取行動與環境互動以嘗試最大化所得的積累獎勵的計算方法。下面是一張簡圖(智慧體—環境迴圈),圖來自於強化學習簡介(第二版)(Reinforcement Learning: An Introduction 2nd Edition,http://incompleteideas.net/sutton/book/the-book-2nd.html)。

使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車

在當前狀態下 St 的智慧體採取行動 At,環境對這個動作進行互動和回應,返回一個新的狀態 S(t+1),並給智慧體一個 R(t+1) 的獎勵。然後智慧體選擇下一個行動,然後這個迴圈一直重複,直到問題被解決了或者被中斷了。

最近強化學習被用來解決一些挑戰性問題,從遊戲到機器人學。在工業應用中,強化學習也開始作為一個實際元件出現,例如資料中心冷卻。而強化學習的絕大多數成功則來自於單智慧體領域,在單智慧體領域中不需要對其他行動者的行為進行建模和預測。

然而,也存在很多涉及多智慧體互動的重要應用,其中會出現共同進化的智慧體的新興行為和複雜度。例如,多機器人控制,通訊和語言的發現,多個玩家參與的遊戲,以及對社會困境的分析都會涉及多智慧體領域。相關的問題也可以以不同的級別和水平來等同於多智慧體問題,例如分層強化學習的變體也可以被看做多智慧體系統。此外,多智慧體自我模擬最近也被證明是一個有用的訓練正規化。在構建能夠與人類有效互動的人工智慧系統時,將強化學習成功地擴充套件到多智慧體問題中是很關鍵的。

不幸的是,Q-learning 和策略梯度等傳統的強化學習方法不能很好地適應於多智慧體環境。一個問題是,每個智慧體的策略都是隨訓練過程發生變化的,並且在任意單獨的智慧體的角度看來,環境會變得不穩定(以一種在智慧體自己的策略中沒有解釋解釋得方式)。這就帶來了學習穩定性的挑戰,以及避免了對過去經驗回放(experience replay)的直接利用,經驗回放對穩定深度 Q 學習是很關鍵的。另一方面,當需要多智慧體協作的時候,策略梯度方法會表現出較高的方差。

或者,還可以使用基於模型的策略優化,這種優化可以通過反向傳播學到最佳策略,但是這需要一個不同的關於環境變化的模型以及關於智慧體之間互動的假設。從優化角度來看,將這些方法應用在競爭環境中也是有挑戰的,其在對抗訓練方法中已被證明存在高度的不穩定性。

卷積神經網路

卷積神經網路革新了模式識別。在卷積神經網路廣泛使用之前,絕大多數模式識別任務都是用初始階段的人工特徵加一個分類器執行的。

卷積神經網路的突破就是:所有的特徵都是從樣本自動學習到的。卷積神經網路在圖下個識別任務上特別強大,因為卷積運算能夠捕獲影像的二維本質特點。此外,通過使用卷積核來掃描整個圖片,與所有的運算次數相比,最終學到的引數會相對少一些。

儘管卷積神經網路在商業應用中已經有超過 20 年的歷史了,但是對它們的應用在最近幾年才開始爆發(這是因為以下兩個進展):

  • 有了便於使用的大規模的標籤資料集,例如大規模視覺識別挑戰(Large Scale Visual Recognition Challenge,ILSVRC)。

  • 在大規模並行的 GPU 上實現了卷積神經網路,這極大的加快了學習和推理過程的速度。

在這篇文章中,我們描述的卷積神經網路已經超出了簡單模式識別的範疇。它能夠學習到控制一輛自動汽車所需的所有過程。

使用卷積神經網路和 OpenAI Gym,我們可以建立一個多智慧體的系統,這些模型可以自動駕駛馬里奧賽車,並且彼此競爭。

使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車

在 Royal Raceway 賽道上(未訓練)進行駕駛

必要條件和安裝

  • Ubuntu

  • 最新支援 CUDA 的 NVIDIA GPU

  • mupen64plus N64 模擬器

  • MarioKart(馬里奧賽車)N64 ROM

  • OpenAI Gym

  • Mupen64Plus Gym 環境

  • tensorflow-gpu

OpenAI Gym

OpenAI Gym 是用來開發和對比強化學習演算法的工具箱

在強化學習中有兩個基本的概念:環境(也就是智慧體所處的外部世界)和智慧體(也就是所開發的演算法)。智慧體向環境傳送行動,環境回應一個觀察和獎勵(也就是一個分數)。

核心 gym 介面是一個 Env(https://github.com/openai/gym/blob/master/gym/core.py)。這裡沒有提供智慧體的介面,需要你去開發。下面是你應該知道的 Env 的方法:

  • reset(self):重設環境狀態。返回狀態觀察。

  • step(self, action):將環境推進一個時間步長。返回觀察值、獎勵、done 和 info。

  • render(self, mode=』human』, close=False):提交一幀環境。預設的模式會做一些友好的事情,例如彈出一個視窗。將最近的標誌訊號傳送到這種視窗。

錄製和訓練

我們開發了一個 python 指令碼來捕捉模擬器的螢幕,以及遊戲手柄和操縱桿的位置。

這個指令碼可以讓我們選擇用來儲存訓練資料的資料夾。它也能夠實時的繪製出被觸發的遊戲手柄的命令。

使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車

記錄玩家的行動

這個指令碼主要使用輸入 python 模組來記錄被按下去的按鈕以及操縱桿的位置。我使用 PS4 DualShock 4 手柄來訓練這個模型。記錄部分是按照這樣配置的:我們每隔 200ms 對模擬器進行一次截圖操作。

為了開始記錄,你必須遵循以下的步驟:

1. 啟動你的模擬器程式(mupen64plus),執行 Mario Kart 64。

2. 保證你將操縱桿連線好了,並且 t mupen64plus 在使用簡易直控媒體層(SDL)外掛。

3. 執行 record.py 指令碼

4. 確保圖形返回相應操縱桿的輸入操作

5. 將模擬器置於合適的位置(左上角),以保證程式能夠擷取到影像。

6. 開始記錄,並且將某一個級別的遊戲玩遍。

訓練資料

record.py 指令碼會將你所玩過的一個級別的所有級別的截圖儲存下來,同時還儲存了一個 data.csv 的檔案。

使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車訓練資料

data.csv 包含與玩家所使用的一系列控制操作相關的圖片的路徑連結。

使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車

data.csv

當我們擷取到了足夠多的訓練資料之後,下一步就是開始實際的訓練。開始訓練之前,需要準備一下我們的資料。

執行 utils.py 來準備樣本:用一個由樣本的路徑組成的陣列來構建用來訓練模型的矩陣 X 和 y。zsh 會擴充套件到所有的樣本路徑。傳遞一個全域性路徑也是可以的.

X 是三維影像矩陣。

y 是期望的操縱桿輸出陣列

 [0] joystick x axis
  [1] joystick y axis
  [2] button a
  [3] button b
  [4] button rb

模型

我們訓練一個卷積神經網路模型,作為從原始單個前置攝像頭直接到控制命令的對映。最終結果證明這個端到端的方法識特別強大的。使用最少的人類玩家的訓練資料,這個系統就能夠學會在擁擠的道路上駕駛馬里奧賽車,無論是在有分道標誌還是沒有分道標誌的高速路上。它也能夠在沒有明顯視覺引導的區域進行操作,例如停車區或者沒有鋪砌的道路上。

系統自動地學會了必要處理步驟的內部表徵,例如僅僅使用人類操縱的角度作為訓練訊號來檢測有用道路特徵。我們從未顯式地訓練它去做這種檢測,例如,使用道路輪廓。

使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車

訓練神經網路

與這個問題的顯式分解(例如道路標誌檢測、路徑規劃好和控制)相比,我們的端到端系統對這些處理步驟做了同時優化。我們認為這種處理機制會帶來更好的效能和更小規模的系統。會得到更好效能的原因是內部元件自優化到更好的全域性系統性,而不是去優化由人類選擇的中間判斷標準,例如道路檢測。這種標準當然是便於人類理解的,但是它不能自動化地保證最大化系統的效能。而最終網路規模會比較小的原因可能是系統在更少數量的步驟內學會了解決這個問題。

網路架構

我們的網路結構是對 NVIDIA 這篇論文中所描述的結構的實現 (https://arxiv.org/pdf/1604.07316.pdf)。

我們訓練網路權重引數,以最小化網路控制命令的結果和其他人類駕駛員的控制輸出之間的均方差,或者為偏離中心或者發生旋轉的圖片調整命令。我們的網路結構如圖 4 所示。網路包含 9 層,包括一個正則化層,5 個卷積層和 3 個全連線層。輸入被分割成 YUV 色彩空間的平面,然後被傳遞到網路中。網路的第一層執行影像正則化的操作。正則化是硬編碼處理,它在學習的過程中不會被調整。在網路中執行正則化可以使得正則化方案隨著網路的結構而改變,並且可以通過 GPU 處理來加速。卷積層是被設計用來進行特徵提取的,是根據經驗從一系列的層配置中選擇的。我們在前三個卷積層中使用卷積核為 5×5,步進為 2×2 的步進卷積,在最後兩層使用卷積核為 3×3 的非步進卷積。在 5 個卷積層之後是三個全連線層,最終的輸出值是反轉半徑。全連線層被設計用來進行轉向控制,但是我們要注意,通過以端到端的形式訓練網路,所以不太可能明確地區別網路中的哪一部分屬於控制器,哪一部分屬於特徵提取器。

使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車

卷積神經網路結構。網路總共擁有大約 2700 萬個連線和 25 萬個引數

訓練

train.py 會基於 google 的 TensorFlow 訓練一個模型,同時會採用 cuDNN 為 GPU 加速。訓練會持續一段時間(大約 1 小時),具體耗時會因所用的訓練資料和系統環境相關。當訓練結束的時候,程式會將模型儲存在硬碟上。

單智慧體的自主駕駛

play.py 會使用 gym-mupen64plus 環境來執行在馬里奧賽車環境中對智慧體的訓練。這個環境會捕獲模擬器的截圖。這些影像會被送到模型中以獲取將要傳送的操縱桿命令。人工智慧操縱桿命令可以通過「LB」按鈕來重寫。

使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車

自主駕駛的 LuiGi 賽道

模型的訓練使用了以下的環境

  • Luigi 賽道上的 4 次競賽

  • Kalimari 沙漠賽道中的 2 次競賽

  • Mario 賽道上的兩次競賽

即使在小的資料集上訓練,模型有時候也能夠泛化到一個新賽道上(例如上述的 Royal Raceway)。

多智慧體的自主駕駛

為了讓智慧體自主駕駛,OpenAI mupen64plus gym 環境需要連線到一個「自動程式」輸入外掛。因為我們對多智慧體環境感興趣,所以我們需要一種能夠使用多智慧體輸入程式的方式。因此,基於 mupen64plus-input-bot 和官方外掛 API(https://github.com/mupen64plus/mupen64plus-core/blob/master/doc/emuwiki-api-doc/Mupen64Plus-v2.0-Plugin-API.mediawiki#input-plugin-api),我們建立了 4 個玩家輸入自動程式。自動輸入程式後面的主要思想就是一個 python 伺服器。它能夠傳送 JSON 命令,並把這些命令轉譯到較低水平的指令。

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include <string.h>

#include <netinet/tcp.h>
#include <sys/socket.h>
#include <syspes.h>
#include <netinet/in.h>
#include <netdb.h>

#include "plugin.h"
#include "controller.h"

#include "json.h"
#include "json_tokener.h"

#define HOST "localhost"
#define CONTROLLER_PORT1 8082
#define CONTROLLER_PORT2 8083

int socket_connect(char *host, int portno) {
 struct hostent *server;
 struct sockaddr_in serv_addr;
 int sockfd;

 /* create the socket */
 sockfd = socket(AF_INET, SOCK_STREAM, 0);
 if (sockfd < 0) DebugMessage(M64MSG_ERROR, "ERROR opening socket");

 /* lookup the ip address */
 server = gethostbyname(host);
 if (server == NULL) DebugMessage(M64MSG_ERROR, "ERROR, no such host");

 /* fill in the structure */
 memset(&serv_addr, 0, sizeof(serv_addr));
 serv_addr.sin_family = AF_INET;
 serv_addr.sin_port = htons(portno);
 memcpy(&serv_addr.sin_addr.s_addr, server->h_addr, server->h_length);

 /* connect the socket */
 if (connect(sockfd, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) {
 DebugMessage(M64MSG_INFO, "ERROR connecting, please start bot server.");
 return -1;
 }


 return sockfd;
}

void clear_controller(Control) {
 controller[Control].buttons.R_DPAD = 0;
 controller[Control].buttons.L_DPAD = 0;
 controller[Control].buttons.D_DPAD = 0;
 controller[Control].buttons.U_DPAD = 0;
 controller[Control].buttons.START_BUTTON = 0;
 controller[Control].buttons.Z_TRIG = 0;
 controller[Control].buttons.B_BUTTON = 0;
 controller[Control].buttons.A_BUTTON = 0;
 controller[Control].buttons.R_CBUTTON = 0;
 controller[Control].buttons.L_CBUTTON = 0;
 controller[Control].buttons.D_CBUTTON = 0;
 controller[Control].buttons.U_CBUTTON = 0;
 controller[Control].buttons.R_TRIG = 0;
 controller[Control].buttons.L_TRIG = 0;
 controller[Control].buttons.X_AXIS = 0;
 controller[Control].buttons.Y_AXIS = 0;
}

void read_controller(int Control) {
 int port;

 // Depending on controller, select whether port 1 or port 2
 switch (Control) {
 case 0:
 port = CONTROLLER_PORT1;
 break;
 case 1:
 port = CONTROLLER_PORT2;
 break;
 default:
 port = CONTROLLER_PORT1;
 }

 if(Control == 1){
 // Ignore controller 1
 // return;
 }

 // DebugMessage(M64MSG_INFO, "Controller #%d listening on port %d", Control, port );

 int sockfd = socket_connect(HOST, port);

 if (sockfd == -1) {
 clear_controller(Control);
 return;
 }

 int bytes, sent, received, total;
 char message[1024], response[4096]; // allocate more space than required.
 sprintf(message, "GET / HTTP/1.0\r\n\r\n");

 /* print the request */
 #ifdef _DEBUG
 DebugMessage(M64MSG_INFO, "[REQUEST] PORT %d: %s", port, message);
 #endif


 /* send the request */
 total = strlen(message);
 sent = 0;
 do {
 bytes = write(sockfd, message + sent, total - sent);
 if (bytes < 0)
 DebugMessage(M64MSG_ERROR, "ERROR writing message to socket");
 if (bytes == 0)
 break;
 sent += bytes;
 } while (sent < total);

 /* receive the response */
 memset(response, 0, sizeof(response));
 total = sizeof(response) - 1;
 received = 0;
 do {
 bytes = read(sockfd, response + received, total - received);
 if (bytes < 0)
 DebugMessage(M64MSG_ERROR, "ERROR reading response from socket");
 if (bytes == 0)
 break;
 received += bytes;
 } while (received < total);

 if (received == total)
 DebugMessage(M64MSG_ERROR, "ERROR storing complete response from socket");

/* print the response */
#ifdef _DEBUG
 DebugMessage(M64MSG_INFO, "[RESPONSE] PORT %d: %s", port, response);
#endif

 /* parse the http response */
 char *body = strtok(response, "\n");
 for (int i = 0; i < 5; i++)
 body = strtok(NULL, "\n");

 /* parse the body of the response */
 json_object *jsonObj = json_tokener_parse(body);

/* print the object */
#ifdef _DEBUG
 DebugMessage(M64MSG_INFO, json_object_to_json_string(jsonObj));
#endif

 controller[Control].buttons.R_DPAD =
 json_object_get_int(json_object_object_get(jsonObj, "R_DPAD"));
 controller[Control].buttons.L_DPAD =
 json_object_get_int(json_object_object_get(jsonObj, "L_DPAD"));
 controller[Control].buttons.D_DPAD =
 json_object_get_int(json_object_object_get(jsonObj, "D_DPAD"));
 controller[Control].buttons.U_DPAD =
 json_object_get_int(json_object_object_get(jsonObj, "U_DPAD"));
 controller[Control].buttons.START_BUTTON =
 json_object_get_int(json_object_object_get(jsonObj, "START_BUTTON"));
 controller[Control].buttons.Z_TRIG =
 json_object_get_int(json_object_object_get(jsonObj, "Z_TRIG"));
 controller[Control].buttons.B_BUTTON =
 json_object_get_int(json_object_object_get(jsonObj, "B_BUTTON"));
 controller[Control].buttons.A_BUTTON =
 json_object_get_int(json_object_object_get(jsonObj, "A_BUTTON"));
 controller[Control].buttons.R_CBUTTON =
 json_object_get_int(json_object_object_get(jsonObj, "R_CBUTTON"));
 controller[Control].buttons.L_CBUTTON =
 json_object_get_int(json_object_object_get(jsonObj, "L_CBUTTON"));
 controller[Control].buttons.D_CBUTTON =
 json_object_get_int(json_object_object_get(jsonObj, "D_CBUTTON"));
 controller[Control].buttons.U_CBUTTON =
 json_object_get_int(json_object_object_get(jsonObj, "U_CBUTTON"));
 controller[Control].buttons.R_TRIG =
 json_object_get_int(json_object_object_get(jsonObj, "R_TRIG"));
 controller[Control].buttons.L_TRIG =
 json_object_get_int(json_object_object_get(jsonObj, "L_TRIG"));
 controller[Control].buttons.X_AXIS =
 json_object_get_int(json_object_object_get(jsonObj, "X_AXIS"));
 controller[Control].buttons.Y_AXIS =
 json_object_get_int(json_object_object_get(jsonObj, "Y_AXIS"));

 close(sockfd);
}

多控制器輸入程式

此外,為了支援多智慧體的 gym 環境,我們必須更新 gym-mupen64plus。因此,我們 fork 了官方的程式碼庫並建立了我們自己的 gym-mupen64plus。主要的區別都是跟步進/觀察/獎勵函式相關。我們需要一種能夠僅僅檢視一部分截圖,為了知道下一步應該採取什麼行動來獲得所需的資訊。

為了啟動多智慧體的馬里奧賽車,只需要執行:

play.py --num_agents=2

這個指令碼會根據智慧體的數目來建立 python 伺服器,並給每個智慧體分配埠。然後,使用 mupen64plus 環境 ttps://github.com/bzier/gym-mupen64plus),這個指令碼將會選擇隨機玩家來開始競賽。

現在,第一個智慧體得到了上述的 CNN 模型,同時第二個智慧體也得到了一個非常通用的 CNN,它包含 3 個卷積層和 2 個全連線層。每個模型都得到了一定比例的螢幕,然後預測操縱桿的位置和速度按鈕。

使用Gym和CNN構建多智慧體自動駕駛馬里奧賽車

值得一提的是,為了讓遊戲看起來比較平滑,我們為每個玩家建立了一個新執行緒(https://github.com/bzier/gym-mupen64plus)。每個執行緒必須提供基於全域性狀態的行動。

if __name__ == '__main__':
  env = gym.make('Mario-Kart-Luigi-Raceway-Multi-v0')
  obs = env.reset()
  env.render()

  while not end_episode:
    # Action should be multi-threaded + setting agent
    for i in range(num_agents):
        agent = i+1

        # Set current agent
        CurrentAgent.set_current_agent(agent)

        # CReate Thread
        thread = AgentThread(target=get_action, args=(agent, obs,actors[i],))
        thread.start()

        # Get action
        action = thread.join()
        # cprint('[Gym Thread] Got action %s for agent %i' %((action,agent)),'red')
        obs, reward, end_episode, info = env.step(action)

  env.render()
  #self.queue.put('render')
  total_reward += reward

從 main 檔案中抽取出的程式碼

哪個模型更好一些?

我們假設能夠讓智慧體贏得競賽的模型是更準確的,是效能更好的。

結論

我們建立了一個含有使用不同模型的多智慧體的系統,這些智慧體可以為了贏得比賽而相互競爭。當結合強化學習時,系統回答了一個關鍵問題:為了得到獎勵,什麼樣的模型效能最好。我們的系統可以被用作了解一個模型比其他模型好的基準工具。


原文連結:https://medium.com/@aymen.mouelhi/multi-agents-self-driving-mario-kart-with-tensorflow-and-cnns-c0f2812b4c50

相關文章