用 Java 訓練出一隻“不死鳥”

削微寒發表於2020-12-23

作者:Kingyu & Lanking

FlappyBird 是 2013 年推出的一款手機遊戲,因其簡單的玩法但極度困難的設定迅速走紅全網。隨著深度學習(DL)與增強學習(RL)等前沿演算法的發展,我們可以使用 Java 非常方便地訓練出一個智慧體來控制 Flappy Bird。

故事開始於《GitHub 上的大佬們打完招呼,會聊些什麼?》,那麼,今天我們就來一起看一下如何用 Java 訓練出一個不死鳥。遊戲專案我們使用了一個僅用 Java 基本類庫編寫的 FlappyBird 遊戲。在訓練方面,我們使用 DeepJavaLibrary 一個基於 Java 的深度學習框架來構建增強學習訓練網路並進行訓練。經過了差不多 300 萬步(四小時)的訓練後,小鳥已經可以獲得最高 8000 多分的成績,靈活穿梭於水管之間。

在本文中,我們將從原理開始一步一步實現增強學習演算法並用它對遊戲進行訓練。如果任何一個時刻不清楚如何繼續進行下去,可以參閱專案的原始碼。

專案地址:https://github.com/kingyuluk/RL-FlappyBird

增強學習(RL)的架構

在這一節會介紹主要用到的演算法以及神經網路,幫助你更好的瞭解如何進行訓練。本專案與 DeepLearningFlappyBird 使用了類似的方法進行訓練。演算法整體的架構是 Q-Learning + 卷積神經網路(CNN),把遊戲每一幀的狀態儲存起來,即小鳥採用的動作和採用動作之後的效果,這些將作為卷積神經網路的訓練資料。

CNN 訓練簡述

CNN 的輸入資料為連續的 4 幀影像,我們將這影像 stack 起來作為小鳥當前的“observation”,影像會轉換成灰度圖以減少所需的訓練資源。影像儲存的矩陣形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 陣列裡的元素就是當前幀的畫素值,這些資料將輸入到 CNN 後將輸出 (batch size, 2) 的矩陣,矩陣的第二個維度就是小鳥 (振翅不採取動作) 對應的收益。

訓練資料

在小鳥採取動作後,我們會得到 preObservation and currentObservation 即是兩組 4 幀的連續的影像表示小鳥動作前和動作後的狀態。然後我們將 preObservation, currentObservation, action, reward, terminal 組成的五元組作為一個 step 存進 replayBuffer 中。它是一個有限大小的訓練資料集,他會隨著最新的操作動態更新內容。

public void step(NDList action, boolean training) {
    if (action.singletonOrThrow().getInt(1) == 1) {
        bird.birdFlap();
    }
    stepFrame();
    NDList preObservation = currentObservation;
    currentObservation = createObservation(currentImg);
    FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
            preObservation, currentObservation, action, currentReward, currentTerminal);
    if (training) {
        replayBuffer.addStep(step);
    }
    if (gameState == GAME_OVER) {
        restartGame();
    }
}

訓練的三個週期

訓練分為 3 個不同的週期以更好地生成訓練資料:

  • Observe(觀察) 週期:隨機產生訓練資料
  • Explore (探索) 週期:隨機與推理動作結合更新訓練資料
  • Training (訓練) 週期:推理動作主導產生新資料

通過這種訓練模式,我們可以更好的達到預期效果。

處於 Explore 週期時,我們會根據權重選取隨機的動作或使用模型推理出的動作來作為小鳥的動作。訓練前期,隨機動作的權重會非常大,因為模型的決策十分不準確 (甚至不如隨機)。在訓練後期時,隨著模型學習的動作逐步增加,我們會不斷增加模型推理動作的權重並最終使它成為主導動作。調節隨機動作的引數叫做 epsilon 它會隨著訓練的過程不斷變化。

public NDList chooseAction(RlEnv env, boolean training) {
    if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
        return env.getActionSpace().randomAction();
    } else return baseAgent.chooseAction(env, training);
}

訓練邏輯

首先,我們會從 replayBuffer 中隨機抽取一批資料作為作為訓練集。然後將 preObservation 輸入到神經網路得到所有行為的 reward(Q)作為預測值:

NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
        .mul(actionInput.singletonOrThrow())
        .sum(new int[]{1}));

postObservation 同樣會輸入到神經網路,根據馬爾科夫決策過程以及貝爾曼價值函式計算出所有行為的 reward(targetQ)作為真實值:

// 將 postInput 輸入到神經網路中得到 targetQReward 是 (batchsize,2) 的矩陣。根據 Q-learning 的演算法,每一次的 targetQ 需要根據當前環境是否結束算出不同的值,因此需要將每一個 step 的 targetQ 單獨算出後再將 targetQ 堆積成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length]; 
for (int i = 0; i < batchSteps.length; i++) {
    if (batchSteps[i].isTerminal()) {
        targetQValue[i] = batchSteps[i].getReward();
    } else {
        targetQValue[i] = targetQReward.singletonOrThrow().get(i)
                .max()
                .mul(rewardDiscount)
                .add(rewardInput.singletonOrThrow().get(i));
    }
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

在訓練結束時,計算 Q 和 targetQ 的損失值,並在 CNN 中更新權重。

卷積神經網路模型(CNN)

我們採用了採用了 3 個卷積層,4 個 relu 啟用函式以及 2 個全連線層的神經網路架構。

layer input shape output shape
conv2d (batchSize, 4, 80, 80) (batchSize,4,20,20)
conv2d (batchSize, 4, 20 ,20) (batchSize, 32, 9, 9)
conv2d (batchSize, 32, 9, 9) (batchSize, 64, 7, 7)
linear (batchSize, 3136) (batchSize, 512)
linear (batchSize, 512) (batchSize, 2)

訓練過程

DJL 的 RL 庫中提供了非常方便的用於實現強化學習的介面:(RlEnv, RlAgent, ReplayBuffer)。

  • 實現 RlAgent 介面即可構建一個可以進行訓練的智慧體。
  • 在現有的遊戲環境中實現 RlEnv 介面即可生成訓練所需的資料。
  • 建立 ReplayBuffer 可以儲存並動態更新訓練資料。

在實現這些介面後,只需要呼叫 step 方法:

RlEnv.step(action, training);

這個方法會將 RlAgent 決策出的動作輸入到遊戲環境中獲得反饋。我們可以在 RlEnv 中提供的 runEnviroment 方法中呼叫 step 方法,然後只需要重複執行 runEnvironment 方法,即可不斷地生成用於訓練的資料。

public Step[] runEnvironment(RlAgent agent, boolean training) {
    // run the game
    NDList action = agent.chooseAction(this, training);
    step(action, training);
    if (training) {
        batchSteps = this.getBatch();
    }
    return batchSteps;
}

我們將 ReplayBuffer 可儲存的 step 數量設定為 50000,在 observe 週期我們會先向 replayBuffer 中儲存 1000 個使用隨機動作生成的 step,這樣可以使智慧體更快地從隨機動作中學習。

在 explore 和 training 週期,神經網路會隨機從 replayBuffer 中生成訓練集並將它們輸入到模型中訓練。我們使用 Adam 優化器和 MSE 損失函式迭代神經網路。

神經網路輸入預處理

首先將影像大小 resize 成 80x80 並轉為灰度圖,這有助於在不丟失資訊的情況下提高訓練速度。

public static NDArray imgPreprocess(BufferedImage observation) {
    return NDImageUtils.toTensor(
            NDImageUtils.resize(
                    ImageFactory.getInstance().fromImage(observation)
                    .toNDArray(NDManager.newBaseManager(),
                     Image.Flag.GRAYSCALE) ,80,80));
}

然後我們把連續的四幀影像作為一個輸入,為了獲得連續四幀的連續影像,我們維護了一個全域性的影像佇列儲存遊戲執行緒中的影像,每一次動作後替換掉最舊的一幀,然後把佇列裡的影像 stack 成一個單獨的 NDArray。

public NDList createObservation(BufferedImage currentImg) {
    NDArray observation = GameUtil.imgPreprocess(currentImg);
    if (imgQueue.isEmpty()) {
        for (int i = 0; i < 4; i++) {
            imgQueue.offer(observation);
        }
        return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
    } else {
        imgQueue.remove();
        imgQueue.offer(observation);
        NDArray[] buf = new NDArray[4];
        int i = 0;
        for (NDArray nd : imgQueue) {
            buf[i++] = nd;
        }
        return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
    }
}

一旦以上部分完成,我們就可以開始訓練了。訓練優化為了獲得最佳的訓練效能,我們關閉了 GUI 以加快樣本生成速度。並使用 Java 多執行緒將訓練迴圈和樣本生成迴圈分別在不同的執行緒中執行。

List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
    callables.add(new TrainerCallable(model, agent));
}

總結

這個模型在 NVIDIA T4 GPU 訓練了大概 4 個小時,更新了 300 萬步。訓練後的小鳥已經可以完全自主控制動作靈活穿梭與管道之間。訓練後的模型也同樣上傳到了倉庫中供您測試。在此專案中 DJL 提供了強大的訓練 API 以及模型庫支援,使得在 Java 開發過程中得心應手。

本專案完整程式碼:https://github.com/kingyuluk/RL-FlappyBird

相關文章