如何在Python中使用Java類? - codecentric

banq發表於2021-11-21

讓 Java 和 Python 攜手合作非常容易,這在開發原型時尤其有價值。
我們從一個實現 Snake 遊戲邏輯的 Java 程式開始:場上總有一塊食物。每當蛇到達食物時,它就會生長並出現新的食物。如果蛇咬自己或咬牆,遊戲結束。
我們的目標是訓練一個神經網路來控制蛇,讓蛇在犯錯和遊戲結束之前吃掉儘可能多的食物。首先,我們需要一個代表遊戲當前狀態的張量。它充當我們神經網路的輸入,以便網路可以使用它來預測下一步要採取的最佳步驟。為了讓這個例子簡單,我們的張量只是一個包含七個元素的向量,可以是 1 或 0:前四個表示食物是在蛇的右邊、左邊、前面還是後面,接下來的三個條目表示如果蛇頭的左邊、前面和右邊的田地都被一堵牆或蛇的尾巴擋住了。

我們示例的完整原始碼可在 GitHub 上找到。

使用JPype匯入Java類即可:

import jpype
import jpype.imports
from jpype.types import *
 
# launch the JVM
jpype.startJVM(classpath=['../target/autosnake-1.0-SNAPSHOT.jar'])
 
# import the Java module
from me.schawe.autosnake import SnakeLogic
 
# construct an object of the `SnakeLogic` class ...
width, height = 10, 10
snake_logic = SnakeLogic(width, height)
 
# ... and call a method on it
print(snake_logic.trainingState())


JPype 在與 Python 直譯器相同的程式中啟動 JVM,並讓它們使用 Java 本機介面 (JNI) 進行通訊。
其他選項:
  • Jython直接在 JVM 中執行 Python 直譯器,這樣 Python 和 Java 就可以非常高效地使用相同的資料結構。但這對使用原生 Python 庫有一些缺點——因為我們將使用numpy和tensorflow,這對我們來說不是一個選擇。
  • Py4J處於頻譜的另一側。它在 Java 程式碼中啟動一個套接字,它可以透過它與 Python 程式進行通訊。優點是任意數量的 Python 程式可以連線到一個長時間執行的 Java 程式——或者相反,一個 Python 程式可以連線到多個 JVM,甚至透過網路。缺點是套接字通訊的開銷較大。

 

在 Java 中載入模型
使用deeplearning4j將訓練好的模型載入到 Java 中……

// https://deeplearning4j.konduit.ai/deeplearning4j/how-to-guides/keras-import
public class Autopilot {
    ComputationGraph model;
 
    public Autopilot(String pathToModel) {
        try {
            model = KerasModelImport.importKerasModelAndWeights(pathToModel, false);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
 
    // infer the next move from the given state
    public int nextMove(boolean[] state) {
        INDArray input = Nd4j.create(state).reshape(1, state.length);
        INDArray output = model.output(input)[0];
 
        int action = output.ravel().argMax().getInt(0);
 
        return action;
    }
}
呼叫:
public class SnakeLogic {
    Autopilot autopilot = new Autopilot("path/to/model.h5");
 
    public void update() {
        int action = autopilot.nextMove(trainingState());
        turnRelative(action);
 
        // rest of the update omitted
    }
 
    // further methods omitted
}

 

相關文章