用 Java 做個“你畫手機猜”的小遊戲

削微寒發表於2020-09-24

本文適合有 Java 基礎的人群

作者:DJL-Lanking

HelloGitHub 推出的《講解開源專案》系列。有幸邀請到了亞馬遜 + Apache 的工程師:Lanking( https://github.com/lanking520 ),為我們講解 DJL —— 完全由 Java 構建的深度學習平臺,本文為系列的第三篇。

一、前言

在 2018 年時,Google 推出了《猜畫小歌》應用:玩家可以直接與AI進行你畫我猜的遊戲。通過畫出一個房子或者一個貓,AI 會推斷出各種物品被畫出的概率。它的實現得益於深度學習模型在其中的應用,通過深度神經網路的歸納,曾經令人頭疼的繪畫識別也變得易如反掌。現如今,只要使用一個簡單的圖片分類模型,我們便可以輕鬆的實現繪畫識別。試試看這個線上塗鴉小遊戲吧:

線上塗鴉小遊戲:https://djl.ai/website/demo.html#doodle

在當時,大部分機器學習計算任務仍舊需要依託網路在雲端進行。隨著算力的不斷增進,機器學習任務已經可以直接在邊緣裝置部署,包括各類執行安卓系統的智慧手機。但是,由於安卓本身主要是用 Java ,部署基於 Python 的各類深度學習模型變成了一個難題。為了解決這個問題,AWS 開發並開源了 DeepJavaLibrary (DJL),一個為 Java 量身定製的深度學習框架。

在這個文章中,我們將嘗試通過 PyTorch 預訓練模型在在安卓平臺構建一個塗鴉繪畫的應用。由於總程式碼量會比較多,我們這次會挑重點把最關鍵的程式碼完成。你可以後續參考我們完整的專案進行構建。

塗鴉應用完整程式碼:https://github.com/aws-samples/djl-demo/tree/master/android

二、環境配置

為了相容 DJL 需求的 Java 功能,這個專案需要 Android API 26 及以上的版本。你可以參考我們案例配置來節約一些時間,下面是這個專案需要的依賴項:

案例 gradle: https://github.com/aws-samples/djl-demo/blob/master/android/quickdraw_recognition/build.gradle

dependencies {
    implementation 'androidx.appcompat:appcompat:1.2.0'
    implementation 'ai.djl:api:0.7.0'
    implementation 'ai.djl.android:core:0.7.0'
    runtimeOnly 'ai.djl.pytorch:pytorch-engine:0.7.0'
    runtimeOnly 'ai.djl.android:pytorch-native:0.7.0'
}

我們將使用 DJL 提供的 API 以及 PyTorch 包。

三、構建應用

3.1 第一步:建立 Layout

我們可以先建立一個 View class 以及 layout(如下圖)來構建安卓的前端顯示介面。

如上圖所示,你可以在主介面建立兩個 View 目標。PaintView 是用來讓使用者畫畫的,在右下角 ImageView 是用來展示用於深度學習推理的影像。同時我們預留一個按鈕來進行畫板的清空操作。

3.2 第二步: 應對繪畫動作

在安卓裝置上,你可以自定義安卓的觸控事件響應來應對使用者的各種觸控操作。在我們的情況下,我們需要定義下面三種時間響應:

  • touchStart:感應觸碰時觸發
  • touchMove:當使用者在螢幕上移動手指時觸發
  • touchUp:當使用者抬起手指時觸發

與此同時,我們用 paths 來儲存使用者在畫板所繪製的路徑。現在我們看一下實現程式碼。

3.2.1 重寫 OnTouchEventOnDraw 方法

現在我們重寫 onTouchEvent 來應對各種響應:

@Override
public boolean onTouchEvent(MotionEvent event) {
    float x = event.getX();
    float y = event.getY();

    switch (event.getAction()) {
        case MotionEvent.ACTION_DOWN :
            touchStart(x, y);
            invalidate();
            break;
        case MotionEvent.ACTION_MOVE :
            touchMove(x, y);
            invalidate();
            break;
        case MotionEvent.ACTION_UP :
            touchUp();
            runInference();
            invalidate();
            break;
    }

    return true;
}

如上面程式碼所示,你可以新增一個 runInference 方法在 MotionEvent.ACTION_UP 事件響應上。這個方法是用來在使用者繪製完後對結果進行推理。在之後的幾步中,我們會講解它的具體實現。

我們同樣需要重寫 onDraw 方法來展示使用者繪製的影像:

@Override
protected void onDraw(Canvas canvas) {
    canvas.save();
    this.canvas.drawColor(DEFAULT_BG_COLOR);

    for (Path path : paths) {
        paint.setColor(DEFAULT_PAINT_COLOR);
        paint.setStrokeWidth(BRUSH_SIZE);
        this.canvas.drawPath(path, paint);
    }
    canvas.drawBitmap(bitmap, 0, 0, bitmapPaint);
    canvas.restore();
}

真正的影像會儲存在一個 Bitmap 上。

3.2.2 操作開始(touchStart)

當使用者觸碰行為開始時,下面的程式碼會建立一個新的路徑同時記錄路徑中每一個點在螢幕上的座標。

private void touchStart(float x, float y) {
    path = new Path();
    paths.add(path);
    path.reset();
    path.moveTo(x, y);
    this.x = x;
    this.y = y;
}

3.2.3 手指移動(touchMove)

在手指移動中,我們會持續記錄座標點然後將它們構成一個 quadratic bezier. 通過一定的誤差閥值來動態優化使用者的繪畫動作。只有差別超出誤差範圍內的動作才會被記錄下來。

quadratic bezier 文件: https://developer.android.com/reference/android/graphics/Path

private void touchMove(float x, float y) {
    if (x < 0 || x > getWidth() || y < 0 || y > getHeight()) {
        return;
    }
    float dx = Math.abs(x - this.x);
    float dy = Math.abs(y - this.y);

    if (dx >= TOUCH_TOLERANCE || dy >= TOUCH_TOLERANCE) {
        path.quadTo(this.x, this.y, (x + this.x) / 2, (y + this.y) / 2);
        this.x = x;
        this.y = y;
    }
}

3.2.4 操作結束(touchUp)

當觸控操作結束後,下面的程式碼會繪製一個路徑同時計算最小長方形目標框。

private void touchUp() {
    path.lineTo(this.x, this.y);
    maxBound.add(new Path(path));
}

3.3 第三步:開始推理

為了在安卓裝置上進行推理任務,我們需要完成下面幾個任務:

  • 從 URL 讀取模型
  • 構建前處理和後處理過程
  • 從 PaintView 進行推理任務

為了完成以下目標,我們嘗試構建一個 DoodleModel class。在這一步,我們將介紹一些完成這些任務的關鍵步驟。

3.3.1 讀取模型

DJL 內建了一套模型管理系統。開發者可以自定義儲存模型的資料夾。

File dir = getFilesDir();
System.setProperty("DJL_CACHE_DIR", dir.getAbsolutePath());

通過更改 DJL_CACHE_DIR 屬性,模型會被存入相應路徑下。

下一步可以通過定義 Criteria 從指定 URL 處下載模型。下載的 zip 檔案內包含:

  • doodle_mobilenet.pt:PyTorch 模型
  • synset.txt:包含分類任務中所有類別的名稱
Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optModelUrls("https://djl-ai.s3.amazonaws.com/resources/demo/pytorch/doodle_mobilenet.zip")
                    .optTranslator(translator)
                    .build();
return ModelZoo.loadModel(criteria);

上述程式碼同時定義了 translator,它會被用來做圖片的前處理和後處理。

最後,如下述程式碼建立一個 Model 並用它建立一個 Predictor

@Override
protected Boolean doInBackground(Void... params) {
    try {
        model = DoodleModel.loadModel();
        predictor = model.newPredictor();
        return true;
    } catch (IOException | ModelException e) {
        Log.e("DoodleDraw", null, e);
    }
    return false;
}

更多關於模型載入的資訊,請參閱如何載入模型。

DJL 模型載入文件:http://docs.djl.ai/docs/load_model.html

3.3.2 用 Translator 定義前處理和後處理

在 DJL 中,我們定義了 Translator 介面進行前處理和後處理。在 DoodleModel 中我們定義了 ImageClassificationTranslator 來實現 Translator:

ImageClassificationTranslator.builder()
    .addTransform(new ToTensor())
    .optFlag(Image.Flag.GRAYSCALE)
    .optApplySoftmax(true).build());

下面我們詳細闡述 translator 所定義的前處理和後處理如何被用在模型的推理步驟中。當你建立 translator 時,內部程式會自動載入 synset.txt 檔案得到做分類任務時所有類別的名稱。當模型的 predict() 方法被呼叫時,內部程式會先執行所對應的 translator 的前處理步驟,而後執行實際推理步驟,最後執行 translator 的後處理步驟。對於前處理,我們會將 Image 轉化 NDArray,用於作為模型推理過程的輸入。對於後處理,我們對推理輸出的結果(NDArray)進行 softmax 操作。最終返回結果為 Classifications 的一個例項。

自定義 Translator 案例:http://docs.djl.ai/jupyter/pytorch/load_your_own_pytorch_bert.html

3.3.3 用 PaintView 進行推理任務

最後,我們來實現之前定義好的 runInference 方法。

public void runInference() {
    // 拷貝影像
    Bitmap bmp = Bitmap.createBitmap(bitmap);
    // 縮放影像
    bmp = Bitmap.createScaledBitmap(bmp, 64, 64, true);
   // 執行推理任務
    Classifications classifications = model.predict(bmp);
   // 展示輸入的影像
    Bitmap present = Bitmap.createScaledBitmap(bmp, imageView.getWidth(), imageView.getHeight(), true);
    imageView.setImageBitmap(present);
   // 展示輸出的影像
   if (messageToast != null) {
        messageToast.cancel();
    }
    messageToast = Toast.makeText(getContext(), classifications.toString(), Toast.LENGTH_SHORT);
    messageToast.show();
}

這將會建立一個 Toast 彈出頁面用於展示結果,示例如下:

恭喜你!我們完成了一個塗鴉識別小程式!

3.4 可選優化:輸入裁剪

為了得到更高的模型推理準確度,你可以通過擷取影像來去除無意義的邊框部分。

上面右側的圖片會比左邊的圖片有更好的推理結果,因為它所包含的空白邊框更少。你可以通過 Bound 類來尋找圖片的有效邊界,即能把圖中所有白色畫素點覆蓋的最小矩形。在得到 x 軸最左座標,y 軸最上座標,以及矩形高度和寬度後,就可以用這些資訊擷取出我們想要的圖形(如右圖所示)實現程式碼如下:

RectF bound = maxBound.getBound();
int x = (int) bound.left;
int y = (int) bound.top;
int width = (int) Math.ceil(bound.width());
int height = (int) Math.ceil(bound.height());
// 擷取部分影像
Bitmap bmp = Bitmap.createBitmap(bitmap, x, y, width, height);

恭喜你!現在你就掌握了全部教程內容!期待看到你建立的第一個 DoodleDraw 安卓遊戲!

最後,可以在GitHub找到本教程的完整案例程式碼。

塗鴉應用完整程式碼:https://github.com/aws-samples/djl-demo/tree/master/android

關於 DJL

Deep Java Library (DJL) 是一個基於 Java 的深度學習框架,同時支援訓練以及推理。 DJL 博取眾長,構建在多個深度學習框架之上 (TenserFlow、PyTorch、MXNet 等) 也同時具備多個框架的優良特性。你可以輕鬆使用 DJL 來進行訓練然後部署你的模型。

它同時擁有著強大的模型庫支援:只需一行便可以輕鬆讀取各種預訓練的模型。現在 DJL 的模型庫同時支援高達 70 個來自 GluonCV、 HuggingFace、TorchHub 以及 Keras 的模型。

專案地址:https://github.com/awslabs/djl/


關注 HelloGitHub 公眾號

相關文章