基於 keras-js 快速實現瀏覽器內的 CNN 手寫數字識別

Starkwang發表於2018-01-26

在這篇文章中,我會快速地介紹如何使用 keras 訓練一個簡單的識別 MNIST(一個手寫數字資料集)的 CNN(卷積神經網路),並且把訓練好的網路應用到 web 瀏覽器內。

DEMO 地址:starkwang.github.io/keras-js-de…

基於 keras-js 快速實現瀏覽器內的 CNN 手寫數字識別


零、準備工作

首先需要給你的電腦安裝 keras,具體安裝的步驟請參考 keras 官方文件


一、快速入門

首先十分推薦閱讀 tensorflow 官方文件中的 MNIST For ML Beginners,這裡是極客學院的中文翻譯

MNIST 是一個很流行的入門級機器學習/計算機視覺資料集,它包含 0 - 9 的各種手寫數字圖片:

基於 keras-js 快速實現瀏覽器內的 CNN 手寫數字識別

每張圖片的尺寸均為 28 * 28,用一個 28 * 28 的二維陣列來表示,換句話說,每張圖片都是由 784 個畫素點組成,每個畫素點的值在 0 - 255 之間。

比如下面就是一個 "3" 的資料:

000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 038 043 105 255 253 253 253 253 253 174 006 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 043 139 224 226 252 253 252 252 252 252 252 252 158 014 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 178 252 252 252 252 253 252 252 252 252 252 252 252 059 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 109 252 252 230 132 133 132 132 189 252 252 252 252 059 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 004 029 029 024 000 000 000 000 014 226 252 252 172 007 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 085 243 252 252 144 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 088 189 252 252 252 014 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 091 212 247 252 252 252 204 009 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 032 125 193 193 193 253 252 252 252 238 102 028 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 045 222 252 252 252 252 253 252 252 252 177 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 045 223 253 253 253 253 255 253 253 253 253 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 031 123 052 044 044 044 044 143 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 015 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 086 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 005 075 009 000 000 000 000 000 000 098 242 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 061 183 252 029 000 000 000 000 018 092 239 252 252 243 065 000 000 000 000 000 000 000 000 
000 000 000 000 000 208 252 252 147 134 134 134 134 203 253 252 252 188 083 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 208 252 252 252 252 252 252 252 252 253 230 153 008 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 049 157 252 252 252 252 252 217 207 146 045 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 007 103 235 252 172 103 024 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
複製程式碼

使用 keras,可以很方便地匯入 MNIST 資料集:

from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
複製程式碼

總體來說,我們的想要得到的網路模型,是有一個固定的輸入輸出的:

  • 輸入為一個 28 * 28 的二維整數陣列
  • 輸出是一個長度為 10 的陣列,依次表示 0-9 的可能性(例如如果有一張圖片 80% 概率為 1, 20% 概率為 7的話,那麼這個陣列就是 [0, 0.8, 0, 0, 0, 0, 0, 0.2, 0, 0]

二、使用 keras 訓練網路

我們想要訓練的模型,由以下幾層網路組成:

  1. 32 個 3x3 卷積核的卷積層
  2. 64 個 3x3 卷積核的卷積層
  3. 取樣因子為 (2, 2) 的池化層
  4. Dropout 層
  5. Flatten 層
  6. ReLu 全連線層
  7. Dropout 層
  8. Softmax 全連線層

用 keras 訓練一個識別 MNIST 的 CNN 網路非常方便,下面是一個官方給出的例子(原始碼在此):

from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 12

# input image dimensions
img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# Save model
model.save('myMnistCNN.h5')
複製程式碼

如果已經安裝好了 keras,直接執行即可:

python mnist_cnn.py
複製程式碼

三、轉換輸出模型

獲得訓練好的 .h5 檔案之後,模型還不能直接使用,因為我們需要對它進行轉編碼,keras-js 提供了一個 python 指令碼來自動執行:

python ./python/encoder.py -q myMnistCNN.h5
複製程式碼

這個指令碼會把 .h5 檔案轉編碼為 keras-js 可讀的格式,裡面包含了訓練好的神經網路的所有模型和引數。

四、使用 keras-js 匯入模型

首先需要引入 keras-js,可以通過 script 標籤直接引入:

<script src="https://unpkg.com/keras-js"></script>
複製程式碼

也可以通過 npm 安裝後使用 webpack 構建引入,參考這裡

接下來就可以直接建立一個 Model,keras-js 會自動載入對應的 bin 檔案:

const model = new KerasJS.Model({
    filepath: '/path/to/mnist_cnn.bin',
    gpu: true,
    transferLayerOutputs: true
})
複製程式碼

初始化完畢之後,就可以用於 MNIST 識別了,輸入是一個長度為 784 的陣列(包含 28*28 各個畫素點的灰度值),輸出是一個長度為 10 的陣列(0-9的概率):

(可以使用上文中給的那個 "3" 的資料範例)

model
  .ready()
  .then(() => {
    // data 是一個長度為 784 的陣列,每一項都介於 0 - 255 之間
    // 這裡我們需要把陣列轉換為 Float32 型別
    const inputData = new Float32Array(data)
    // 識別
    return model.predict(inputData)
  })
  .then(outputData => {
    // 輸出為 0-9 的概率,例如:
    // { output: [0, 0, 0, 0.8, 0, 0, 0.2, 0, 0, 0] }
  })
  .catch(err => {
    // ...
  })
複製程式碼

五、Canvas 實現一個手寫板

最後一步就是實現一個手寫板,具體的程式碼就不放上來了,主要就是通過 mousedownmousemovemouseup 事件來繪製圖形。

繪製完畢之後,呼叫 ctx.getImageData,就可以得到 canvas 內的畫素資料,每個畫素對應四個數值,依次是每個點的 rgba 值,處理之後就可以得到長度為 784 的灰度陣列了。然後使用上文提到的 model.predict 即可。

基於 keras-js 快速實現瀏覽器內的 CNN 手寫數字識別

相關文章