TensorFlow Lite for Android 初探(附demo)

Tony沈哲發表於2018-11-08

一. TensorFlow Lite

TensorFlow Lite介紹.jpeg

TensorFlow Lite特性.jpeg

TensorFlow Lite使用.jpeg

TensorFlow Lite 是用於移動裝置和嵌入式裝置的輕量級解決方案。TensorFlow Lite 支援 Android、iOS 甚至樹莓派等多種平臺。

我們知道大多數的 AI 是在雲端運算的,但是在移動端使用 AI 具有無網路延遲、響應更加及時、資料隱私等特性。

對於離線的場合,雲端的 AI 就無法使用了,而此時可以在移動裝置中使用 TensorFlow Lite。

二. tflite 格式

TensorFlow 生成的模型是無法直接給移動端使用的,需要離線轉換成.tflite檔案格式。

tflite 儲存格式是 flatbuffers。

FlatBuffers 是由Google開源的一個免費軟體庫,用於實現序列化格式。它類似於Protocol Buffers、Thrift、Apache Avro。

因此,如果要給移動端使用的話,必須把 TensorFlow 訓練好的 protobuf 模型檔案轉換成 FlatBuffers 格式。官方提供了 toco 來實現模型格式的轉換。

三. 常用的 Java API

TensorFlow Lite 提供了 C ++ 和 Java 兩種型別的 API。無論哪種 API 都需要載入模型和執行模型。

而 TensorFlow Lite 的 Java API 使用了 Interpreter 類(直譯器)來完成載入模型和執行模型的任務。後面的例子會看到如何使用 Interpreter。

四. TensorFlow Lite + mnist 資料集實現識別手寫數字

mnist 是手寫數字圖片資料集,包含60000張訓練樣本和10000張測試樣本。 測試集也是同樣比例的手寫數字資料。每張圖片有28x28個畫素點構成,每個畫素點用一個灰度值表示,這裡是將28x28的畫素展開為一個一維的行向量(每行784個值)。

mnist 資料集獲取地址:yann.lecun.com/exdb/mnist/

下面的 demo 中已經包含了 mnist.tflite 模型檔案。(如果沒有的話,需要自己訓練儲存成pb檔案,再轉換成tflite 格式)

對於一個識別類,首先需要初始化 TensorFlow Lite 直譯器,以及輸入、輸出。

    // The tensorflow lite file
    private lateinit var tflite: Interpreter

    // Input byte buffer
    private lateinit var inputBuffer: ByteBuffer

    // Output array [batch_size, 10]
    private lateinit var mnistOutput: Array<FloatArray>

    init {

        try {
            tflite = Interpreter(loadModelFile(activity))

            inputBuffer = ByteBuffer.allocateDirect(
                    BYTE_SIZE_OF_FLOAT * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE)
            inputBuffer.order(ByteOrder.nativeOrder())
            mnistOutput = Array(DIM_BATCH_SIZE) { FloatArray(NUMBER_LENGTH) }
            Log.d(TAG, "Created a Tensorflow Lite MNIST Classifier.")
        } catch (e: IOException) {
            Log.e(TAG, "IOException loading the tflite file failed.")
        }

    }
複製程式碼

從 asserts 檔案中載入 mnist.tflite 模型:

    /**
     * Load the model file from the assets folder
     */
    @Throws(IOException::class)
    private fun loadModelFile(activity: Activity): MappedByteBuffer {

        val fileDescriptor = activity.assets.openFd(MODEL_PATH)
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        val fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }
複製程式碼

真正識別手寫數字是在 classify() 方法:

val digit = mnistClassifier.classify(Bitmap.createScaledBitmap(paintView.bitmap, PIXEL_WIDTH, PIXEL_WIDTH, false))
複製程式碼

classify() 方法包含了預處理用於初始化 inputBuffer、執行 mnist 模型、識別出數字。

    /**
     * Classifies the number with the mnist model.
     *
     * @param bitmap
     * @return the identified number
     */
    fun classify(bitmap: Bitmap): Int {

        if (tflite == null) {
            Log.e(TAG, "Image classifier has not been initialized; Skipped.")
        }

        preProcess(bitmap)
        runModel()
        return postProcess()
    }

    /**
     * Converts it into the Byte Buffer to feed into the model
     *
     * @param bitmap
     */
    private fun preProcess(bitmap: Bitmap?) {

        if (bitmap == null || inputBuffer == null) {
            return
        }

        // Reset the image data
        inputBuffer.rewind()

        val width = bitmap.width
        val height = bitmap.height

        // The bitmap shape should be 28 x 28
        val pixels = IntArray(width * height)
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height)

        for (i in pixels.indices) {
            // Set 0 for white and 255 for black pixels
            val pixel = pixels[i]
            // The color of the input is black so the blue channel will be 0xFF.
            val channel = pixel and 0xff
            inputBuffer.putFloat((0xff - channel).toFloat())
        }
    }

    /**
     * Run the TFLite model
     */
    private fun runModel() = tflite.run(inputBuffer, mnistOutput)

    /**
     * Go through the output and find the number that was identified.
     *
     * @return the number that was identified (returns -1 if one wasn't found)
     */
    private fun postProcess(): Int {

        for (i in 0 until mnistOutput[0].size) {
            val value = mnistOutput[0][i]
            if (value == 1f) {
                return i
            }
        }

        return -1
    }
複製程式碼

對於 Android 有一個地方需要注意,必須在 app 模組的 build.gradle 中新增如下的語句,否則無法載入模型。

android {
    ......
    aaptOptions {
        noCompress "tflite"
    }
}
複製程式碼

demo 執行效果如下:

識別手寫數字5.png

識別手寫數字7.png

五. 總結

本文只是 TF Lite 的初探,很多細節並沒有詳細闡述。應該會在未來的文章中詳細介紹。

本文 demo 的 github 地址:github.com/fengzhizi71…

當然,也可以跑一下官方的例子:github.com/tensorflow/…


Java與Android技術棧:每週更新推送原創技術文章,歡迎掃描下方的公眾號二維碼並關注,期待與您的共同成長和進步。

TensorFlow Lite for Android 初探(附demo)

相關文章