如何將模型部署到安卓移動端,這裡有一份簡單教程

路雪發表於2018-07-18

截至 2018 年,全球活躍的安卓裝置已經超過了 20 億部。安卓手機的迅速普及在很大程度上得益於各種各樣的智慧應用,從地圖到圖片編輯器無所不有。隨著深度學習技術的興起,移動應用註定會變得更加智慧。深度學習加持下的下一代移動應用將專門為你學習和定製功能。微軟的「SwiftKey」就是一個很好的例子,它能夠透過學習你常用的單詞和短語來幫助你更快地打字。

計算機視覺自然語言處理語音識別以及語音合成等技術可以極大地提高移動應用程式各個方面的使用者體驗。幸運的是,人們現在已經開發出了大量工具,用於簡化在移動應用中部署和管理深度學習模型的過程。在本文中,作者將向大家介紹如何使用 TensorFlow Mobile 將 Pytorch 和 Keras 模型部署到移動裝置上。

使用 TensorFlow Mobile 將模型部署到安卓裝置上包括三個步驟:

  • 將訓練好的模型轉換成 TensorFlow 格式;

  • 向安卓應用新增 TensorFlow Mobile 依賴項;

  • 編寫相關的 Java 程式碼,在你的應用中使用 TensorFlow 模型執行推斷。

在本文中,我將帶你熟悉以上的整個流程,最終完成一個嵌入影像識別功能的安卓應用。

環境設定

在本教程中,我們將使用 Pytorch 和 Keras,選擇你偏好的機器學習框架,並按照說明進行操作。因此,你的環境設定取決於你選擇的框架。

第一步,安裝 TensorFlow:

pip3 install tensorflow

如果你是 PyTorch 開發者,請確保你已經安裝了最新版本的 PyTorch。關於安裝 PyTorch 的說明,請查閱我早前編寫的這篇文章(https://heartbeat.fritz.ai/basics-of-image-classification-with-pytorch-2f8973c51864)。

如果你是一名 Keras 開發者,你可以使用下面的命令安裝相關開發環境:

pip3 install keras
pip3 install h5py

Android Studio(精簡版 3.0)

https://developer.android.com/studio

將 PyTorch 模型轉換為 Keras 模型

本節僅針對於 PyTorch 開發者。如果你使用的是 Keras 框架,你可以直接跳到「將 Keras 模型轉換為 TensorFlow 模型」這一節。

我們需要做的第一件事就是將 PyTorch 模型的引數轉化為其在 Keras 框架下等價的引數。為了簡化這個過程,我編寫了一個指令碼來自動化地進行這個轉換工作。在這篇教程中,我將使用 Squeezenet,這是一種準確率還不錯且規模非常小的移動架構。你可以透過這個連結下載預訓練好的模型(大小僅僅只有 5mb!):https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth。

在轉換權重之前,我們需要在 PyTorch 和 Keras 中定義 Squeezenet 模型。

在兩個框架中都定義 Squeezenet,然後使用下面的方法將 PyTorch 框架的權重遷移到 Keras 框架中。

建立一個 convert.py 檔案,引入下面的程式碼,並且執行指令碼。

import torch
import torch.nn as nn
from torch.autograd import Variable
import keras.backend as K
from keras.models import *
from keras.layers import *

import torch
from torchvision.models import squeezenet1_1


class PytorchToKeras(object):
    def __init__(self,pModel,kModel):
        super(PytorchToKeras,self)
        self.__source_layers = []
        self.__target_layers = []
        self.pModel = pModel
        self.kModel = kModel

        K.set_learning_phase(0)

    def __retrieve_k_layers(self):

        for i,layer in enumerate(self.kModel.layers):
            if len(layer.weights) > 0:
                self.__target_layers.append(i)

    def __retrieve_p_layers(self,input_size):

        input = torch.randn(input_size)

        input = Variable(input.unsqueeze(0))

        hooks = []

        def add_hooks(module):

            def hook(module, input, output):
                if hasattr(module,"weight"):
                    self.__source_layers.append(module)

            if not isinstance(module, nn.ModuleList) and not isinstance(module,nn.Sequential) and module != self.pModel:
                hooks.append(module.register_forward_hook(hook))

        self.pModel.apply(add_hooks)


        self.pModel(input)
        for hook in hooks:
            hook.remove()

    def convert(self,input_size):
        self.__retrieve_k_layers()
        self.__retrieve_p_layers(input_size)

        for i,(source_layer,target_layer) in enumerate(zip(self.__source_layers,self.__target_layers)):

            weight_size = len(source_layer.weight.data.size())

            transpose_dims = []

            for i in range(weight_size):
                transpose_dims.append(weight_size - i - 1)

            self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy().transpose(transpose_dims), source_layer.bias.data.numpy()])

    def save_model(self,output_file):
        self.kModel.save(output_file)
    def save_weights(self,output_file):
        self.kModel.save_weights(output_file)



"""
We explicitly redefine the Squeezent architecture since Keras has no predefined Squeezent
"""

def squeezenet_fire_module(input, input_channel_small=16, input_channel_large=64):

    channel_axis = 3

    input = Conv2D(input_channel_small, (1,1), padding="valid" )(input)
    input = Activation("relu")(input)

    input_branch_1 = Conv2D(input_channel_large, (1,1), padding="valid" )(input)
    input_branch_1 = Activation("relu")(input_branch_1)

    input_branch_2 = Conv2D(input_channel_large, (3, 3), padding="same")(input)
    input_branch_2 = Activation("relu")(input_branch_2)

    input = concatenate([input_branch_1, input_branch_2], axis=channel_axis)

    return input


def SqueezeNet(input_shape=(224,224,3)):



    image_input = Input(shape=input_shape)


    network = Conv2D(64, (3,3), strides=(2,2), padding="valid")(image_input)
    network = Activation("relu")(network)
    network = MaxPool2D( pool_size=(3,3) , strides=(2,2))(network)

    network = squeezenet_fire_module(input=network, input_channel_small=16, input_channel_large=64)
    network = squeezenet_fire_module(input=network, input_channel_small=16, input_channel_large=64)
    network = MaxPool2D(pool_size=(3,3), strides=(2,2))(network)

    network = squeezenet_fire_module(input=network, input_channel_small=32, input_channel_large=128)
    network = squeezenet_fire_module(input=network, input_channel_small=32, input_channel_large=128)
    network = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(network)

    network = squeezenet_fire_module(input=network, input_channel_small=48, input_channel_large=192)
    network = squeezenet_fire_module(input=network, input_channel_small=48, input_channel_large=192)
    network = squeezenet_fire_module(input=network, input_channel_small=64, input_channel_large=256)
    network = squeezenet_fire_module(input=network, input_channel_small=64, input_channel_large=256)

    #Remove layers like Dropout and BatchNormalization, they are only needed in training
    #network = Dropout(0.5)(network)

    network = Conv2D(1000, kernel_size=(1,1), padding="valid", name="last_conv")(network)
    network = Activation("relu")(network)

    network = GlobalAvgPool2D()(network)
    network = Activation("softmax",name="output")(network)


    input_image = image_input
    model = Model(inputs=input_image, outputs=network)

    return model


keras_model = SqueezeNet()


#Lucky for us, PyTorch includes a predefined Squeezenet
pytorch_model = squeezenet1_1()

#Load the pretrained model
pytorch_model.load_state_dict(torch.load("squeezenet.pth"))

#Time to transfer weights

converter = PytorchToKeras(pytorch_model,keras_model)
converter.convert((3,224,224))

#Save the weights of the converted keras model for later use
converter.save_weights("squeezenet.h5")

在完成了上述權重轉換工作後,你現在只需將 Keras 模型儲存為 squeezenet.h5。此時,我們可以將 PyTorch 模型拋在腦後,繼續進行我們的下一步工作。

將 Keras 模型轉化為 TensorFlow 模型

此時,你已經有了一個從 PyTorch 模型轉換而來的 Keras 模型,或者直接使用 Keras 訓練得到的模型。你可以透過這裡的連結(https://github.com/OlafenwaMoses/ImageAI/releases/download/1.0/squeezenetweightstfdimorderingtfkernels.h5)下載預訓練好的 Keras Squeezenet 模型。下一步,將整個模型架構和權重轉換成一個可用於實際生產的 TensorFlow 模型。

建立一個新的 ConvertToTensorflow.py 檔案,新增以下程式碼。

from keras.models import Model
from keras.layers import *
import os
import tensorflow as tf


def keras_to_tensorflow(keras_model, output_dir, model_name,out_prefix="output_", log_tensorboard=True):

    if os.path.exists(output_dir) == False:
        os.mkdir(output_dir)

    out_nodes = []

    for i in range(len(keras_model.outputs)):
        out_nodes.append(out_prefix + str(i + 1))
        tf.identity(keras_model.output[i], out_prefix + str(i + 1))

    sess = K.get_session()

    from tensorflow.python.framework import graph_util, graph_io

    init_graph = sess.graph.as_graph_def()

    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)

    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard

        import_pb_to_tensorboard.import_to_tensorboard(
            os.path.join(output_dir, model_name),
            output_dir)


"""
We explicitly redefine the Squeezent architecture since Keras has no predefined Squeezenet
"""

def squeezenet_fire_module(input, input_channel_small=16, input_channel_large=64):

    channel_axis = 3

    input = Conv2D(input_channel_small, (1,1), padding="valid" )(input)
    input = Activation("relu")(input)

    input_branch_1 = Conv2D(input_channel_large, (1,1), padding="valid" )(input)
    input_branch_1 = Activation("relu")(input_branch_1)

    input_branch_2 = Conv2D(input_channel_large, (3, 3), padding="same")(input)
    input_branch_2 = Activation("relu")(input_branch_2)

    input = concatenate([input_branch_1, input_branch_2], axis=channel_axis)

    return input


def SqueezeNet(input_shape=(224,224,3)):



    image_input = Input(shape=input_shape)


    network = Conv2D(64, (3,3), strides=(2,2), padding="valid")(image_input)
    network = Activation("relu")(network)
    network = MaxPool2D( pool_size=(3,3) , strides=(2,2))(network)

    network = squeezenet_fire_module(input=network, input_channel_small=16, input_channel_large=64)
    network = squeezenet_fire_module(input=network, input_channel_small=16, input_channel_large=64)
    network = MaxPool2D(pool_size=(3,3), strides=(2,2))(network)

    network = squeezenet_fire_module(input=network, input_channel_small=32, input_channel_large=128)
    network = squeezenet_fire_module(input=network, input_channel_small=32, input_channel_large=128)
    network = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(network)

    network = squeezenet_fire_module(input=network, input_channel_small=48, input_channel_large=192)
    network = squeezenet_fire_module(input=network, input_channel_small=48, input_channel_large=192)
    network = squeezenet_fire_module(input=network, input_channel_small=64, input_channel_large=256)
    network = squeezenet_fire_module(input=network, input_channel_small=64, input_channel_large=256)

    #Remove layers like Dropout and BatchNormalization, they are only needed in training
    #network = Dropout(0.5)(network)

    network = Conv2D(1000, kernel_size=(1,1), padding="valid", name="last_conv")(network)
    network = Activation("relu")(network)

    network = GlobalAvgPool2D()(network)
    network = Activation("softmax",name="output")(network)


    input_image = image_input
    model = Model(inputs=input_image, outputs=network)

    return model


keras_model = SqueezeNet()

keras_model.load_weights("squeezenet.h5")


output_dir = os.path.join(os.getcwd(),"checkpoint")

keras_to_tensorflow(keras_model,output_dir=output_dir,model_name="squeezenet.pb")

print("MODEL SAVED")

上述程式碼將 squeezenet.pb 儲存到了我們的 output_dir 資料夾中。它還在同一個資料夾中建立了 TensorBoard 事件檔案。

為了對模型有一個更清晰的理解,你可以在 TensorBoard 中對其視覺化。

你需要開啟命令提示符,輸入:

tensorboard –logdir=output_dir_path

output_dir_path 即 output_dir 的路徑。

當你成功啟動 TensorBoard 後,你將看到下面的對話方塊,請求你開啟 COMPUTER_NAME:6006 的 url 連結。

如何將模型部署到安卓移動端,這裡有一份簡單教程

在你最喜歡的瀏覽器中輸入 URL 地址,會顯示出下圖所示的介面。

如何將模型部署到安卓移動端,這裡有一份簡單教程

請雙擊 IMPORT 來視覺化你的模型。

仔細檢視模型,並且請注意輸入和輸出節點的名字(架構中的第一個和最後一個節點)。

如果你像我在前面的程式碼中那樣命名你的網路層,那麼它們的名稱應該分別是 input_1 和 output_1。

至此,我們的模型已經完全準備就緒,可以進行部署了。

TensorFLow Mobile 新增到你的專案中

TensorFlow 有兩個移動程式庫——「TensorFlow Mobile」和「TensorFlow Lite」。Lite 版本是為極小規模的模型設計的,整個依賴項僅佔用大約 1Mb 的空間。Lite 中的模型也經過了更好的最佳化。最近,在安卓 8 及更高版本中,TensorFlow Lite 使用安卓神經網路 API 進行加速。然而,與「TensorFlow Mobile」不同,Lite 並不能直接用於生產,因為其中一些層的表現可能沒有如預期一樣好。此外,Windows 系統至今還不支援對 Lite 庫的編譯,以及將其模型轉換為本地格式。因此,在這篇教程中,我堅持使用 TensorFlow Mobile。

接下來,如果你沒有現有的安卓專案,請在 Android Studio 中建立一個。在你的 build.gradle 檔案中新增 TensorFlow Mobile 依賴。

implementation ‘org.tensorflow:tensorflow-android:+’

Android Studio 將向你提示同步 gradle(一種專案自動化構建開源工具)。點選 Sync Now,等待同步完成。

此時,你的環境就已經完全設定好了。

在移動 app 中執行推斷

在編寫程式碼進行實際推斷之前,你需要將轉換後的模型(squeezenet.pb)新增到應用程式的資原始檔夾中。在 Android Studio 中,右鍵點選你的專案,跳轉至「Add Folder」(新增資料夾)部分,並選擇「Assets Folder」(資原始檔夾)。這將在你的應用程式目錄中建立一個資原始檔夾。接下來,你需要將模型複製到資原始檔夾中。

你可以透過這個連結(https://github.com/johnolafenwa/Pytorch-Keras-ToAndroid/raw/master/android-sample/app/src/main/assets/labels.json)下載類標籤,並且將檔案複製到資原始檔夾中。

現在你的專案已經包含了進行影像分類所需的一切。

將一個新的 Java 類新增到專案的主程式包中,並將其命名為 ImageUtils,把下面的程式碼複製到其中。

package com.specpal.mobileai;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.os.Environment;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import org.json.*;

/**
 * Utility class for manipulating images.
 **/
public class ImageUtils {
    /**
     * Returns a transformation matrix from one reference frame into another.
     * Handles cropping (if maintaining aspect ratio is desired) and rotation.
     *
     * @param srcWidth Width of source frame.
     * @param srcHeight Height of source frame.
     * @param dstWidth Width of destination frame.
     * @param dstHeight Height of destination frame.
     * @param applyRotation Amount of rotation to apply from one frame to another.
     *  Must be a multiple of 90.
     * @param maintainAspectRatio If true, will ensure that scaling in x and y remains constant,
     * cropping the image if necessary.
     * @return The transformation fulfilling the desired requirements.
     */
    public static Matrix getTransformationMatrix(
            final int srcWidth,
            final int srcHeight,
            final int dstWidth,
            final int dstHeight,
            final int applyRotation,
            final boolean maintainAspectRatio) {
        final Matrix matrix = new Matrix();

        if (applyRotation != 0) {
            // Translate so center of image is at origin.
            matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f);

            // Rotate around origin.
            matrix.postRotate(applyRotation);
        }

        // Account for the already applied rotation, if any, and then determine how
        // much scaling is needed for each axis.
        final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0;

        final int inWidth = transpose ? srcHeight : srcWidth;
        final int inHeight = transpose ? srcWidth : srcHeight;

        // Apply scaling if necessary.
        if (inWidth != dstWidth || inHeight != dstHeight) {
            final float scaleFactorX = dstWidth / (float) inWidth;
            final float scaleFactorY = dstHeight / (float) inHeight;

            if (maintainAspectRatio) {
                // Scale by minimum factor so that dst is filled completely while
                // maintaining the aspect ratio. Some image may fall off the edge.
                final float scaleFactor = Math.max(scaleFactorX, scaleFactorY);
                matrix.postScale(scaleFactor, scaleFactor);
            } else {
                // Scale exactly to fill dst from src.
                matrix.postScale(scaleFactorX, scaleFactorY);
            }
        }

        if (applyRotation != 0) {
            // Translate back from origin centered reference to destination frame.
            matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f);
        }

        return matrix;
    }


    public static Bitmap processBitmap(Bitmap source,int size){

        int image_height = source.getHeight();
        int image_width = source.getWidth();

        Bitmap croppedBitmap = Bitmap.createBitmap(size, size, Bitmap.Config.ARGB_8888);

        Matrix frameToCropTransformations = getTransformationMatrix(image_width,image_height,size,size,0,false);
        Matrix cropToFrameTransformations = new Matrix();
        frameToCropTransformations.invert(cropToFrameTransformations);

        final Canvas canvas = new Canvas(croppedBitmap);
        canvas.drawBitmap(source, frameToCropTransformations, null);

        return croppedBitmap;


    }

    public static float[] normalizeBitmap(Bitmap source,int size,float mean,float std){

        float[] output = new float[size * size * 3];

        int[] intValues = new int[source.getHeight() * source.getWidth()];

        source.getPixels(intValues, 0, source.getWidth(), 0, 0, source.getWidth(), source.getHeight());
        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            output[i * 3] = (((val >> 16) & 0xFF) - mean)/std;
            output[i * 3 + 1] = (((val >> 8) & 0xFF) - mean)/std;
            output[i * 3 + 2] = ((val & 0xFF) - mean)/std;
        }

        return output;

    }

    public static Object[] argmax(float[] array){


        int best = -1;
        float best_confidence = 0.0f;

        for(int i = 0;i < array.length;i++){

            float value = array[i];

            if (value > best_confidence){

                best_confidence = value;
                best = i;
            }
        }



        return new Object[]{best,best_confidence};


    }


    public static String getLabel( InputStream jsonStream,int index){
        String label = "";
        try {

            byte[] jsonData = new byte[jsonStream.available()];
            jsonStream.read(jsonData);
            jsonStream.close();

            String jsonString = new String(jsonData,"utf-8");

            JSONObject object = new JSONObject(jsonString);

            label = object.getString(String.valueOf(index));



        }
        catch (Exception e){


        }
        return label;
    }
}

如果你並不理解上面的大部分程式碼,那也沒事——這是 TensorFlow Mobile 核心庫中沒有實現的兩個標準函式。因此,藉助官方示例的幫助,我寫下了它們,以方便後續操作的便捷進行。

在你的主活動(main activity)中建立 ImageView 和 TextView。它們將被用於顯示影像和預測結果。

在主活動中,你需要載入 TensorFlow-inference 庫,並且初始化一些類變數。在 onCreat 方法前新增以下內容:

 //Load the tensorflow inference library
    static {
        System.loadLibrary("tensorflow_inference");
    }

    //PATH TO OUR MODEL FILE AND NAMES OF THE INPUT AND OUTPUT NODES
    private String MODEL_PATH = "file:///android_asset/squeezenet.pb";
    private String INPUT_NAME = "input_1";
    private String OUTPUT_NAME = "output_1";
    private TensorFlowInferenceInterface tf;

    //ARRAY TO HOLD THE PREDICTIONS AND FLOAT VALUES TO HOLD THE IMAGE DATA
    float[] PREDICTIONS = new float[1000];
    private float[] floatValues;
    private int[] INPUT_SIZE = {224,224,3};

    ImageView imageView;
    TextView resultView;
    Snackbar progressBar;

新增一個計算預測類的函式:

//FUNCTION TO COMPUTE THE MAXIMUM PREDICTION AND ITS CONFIDENCE
    public Object[] argmax(float[] array){


        int best = -1;
        float best_confidence = 0.0f;

        for(int i = 0;i < array.length;i++){

            float value = array[i];

            if (value > best_confidence){

                best_confidence = value;
                best = i;
            }
        }

        return new Object[]{best,best_confidence};


    }

新增接收影像點陣圖並對其進行推斷的函式:

public void predict(final Bitmap bitmap){


        //Runs inference in background thread
        new AsyncTask<Integer,Integer,Integer>(){

            @Override

            protected Integer doInBackground(Integer ...params){

                //Resize the image into 224 x 224
                Bitmap resized_image = ImageUtils.processBitmap(bitmap,224);

                //Normalize the pixels
                floatValues = ImageUtils.normalizeBitmap(resized_image,224,127.5f,1.0f);

                //Pass input into the tensorflow
                tf.feed(INPUT_NAME,floatValues,1,224,224,3);

                //compute predictions
                tf.run(new String[]{OUTPUT_NAME});

                //copy the output into the PREDICTIONS array
                tf.fetch(OUTPUT_NAME,PREDICTIONS);

                //Obtained highest prediction
                Object[] results = argmax(PREDICTIONS);


                int class_index = (Integer) results[0];
                float confidence = (Float) results[1];


                try{

                    final String conf = String.valueOf(confidence * 100).substring(0,5);

                    //Convert predicted class index into actual label name
                   final String label = ImageUtils.getLabel(getAssets().open("labels.json"),class_index);



                   //Display result on UI
                    runOnUiThread(new Runnable() {
                        @Override
                        public void run() {

                            progressBar.dismiss();
                            resultView.setText(label + " : " + conf + "%");

                        }
                    });

                }

                catch (Exception e){


                }


                return 0;
            }



        }.execute(0);

    }

上述程式碼在後臺執行緒中執行預測工作,並且將預測出的類和它的置信度寫入我們之前定義的 TextView 檔案中。

請注意,在主使用者介面(UI)執行緒上執行推斷可能會掛起應用程式。一般而言,我們總是在後臺執行緒執行推斷工作。

為了將本教程的重點放在影像識別的主題上,我簡單地使用了一張新增到資原始檔夾中的鳥的影像。在標準應用程式中,你應該編寫程式碼從檔案系統中載入圖片。

你可以向資原始檔夾新增任何你想要預測的影像。為了進行一次真實的預測,在下面的程式碼中,我們為一個按鈕新增了一個點選事件的監聽器。這個監聽器僅僅載入圖片並且呼叫預測函式。

 @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);


        Toolbar toolbar = (Toolbar) findViewById(R.id.toolbar);
        setSupportActionBar(toolbar);


        //initialize tensorflow with the AssetManager and the Model
        tf = new TensorFlowInferenceInterface(getAssets(),MODEL_PATH);

        imageView = (ImageView) findViewById(R.id.imageview);
        resultView = (TextView) findViewById(R.id.results);

        progressBar = Snackbar.make(imageView,"PROCESSING IMAGE",Snackbar.LENGTH_INDEFINITE);


        final FloatingActionButton predict = (FloatingActionButton) findViewById(R.id.predict);
        predict.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {


                try{

                    //READ THE IMAGE FROM ASSETS FOLDER
                    InputStream imageStream = getAssets().open("testimage.jpg");

                    Bitmap bitmap = BitmapFactory.decodeStream(imageStream);

                    imageView.setImageBitmap(bitmap);

                    progressBar.show();

                    predict(bitmap);
                }
                catch (Exception e){

                }

            }
        });
    }

現在,我們大功告成了!再仔細檢查一遍,以確保你正確地完成了每一步。如果一切正常,請點選「Build APK」(構建安卓安裝包)。

稍等片刻,你的安裝包就構建好了。你可以安裝 APK,執行該應用程式。

執行結果如下:

如何將模型部署到安卓移動端,這裡有一份簡單教程

要想獲得更令人興奮的體驗,你應該實現一些新功能,從安卓檔案系統載入影像,或者使用相機獲取影像,而不是使用資原始檔夾。

總結

移動端的深度學習最終將改變我們構建和使用 app 的方式。透過上面的程式碼片段,你可以很容易地將訓練好的 PyTorch 和 Keras 模型匯出到 TensorFlow 環境下。藉助於 TensorFlow Mobile 的強大功能,並且按照本文中介紹的步驟,你可以為自己的移動應用程式無縫注入優秀的人工智慧功能。

安卓專案的全部程式碼和模型轉換器可以在我的 GitHub 上(https://github.com/johnolafenwa/Pytorch-Keras-ToAndroid)獲得。如何將模型部署到安卓移動端,這裡有一份簡單教程

原文連結:https://heartbeat.fritz.ai/deploying-pytorch-and-keras-models-to-android-with-tensorflow-mobile-a16a1fb83f2

相關文章