瀏覽器中的手寫數字識別

雲水木石發表於2019-04-25

隨著TensorFlow 2.0 alpha的釋出,TensorFlow.js更新到首個正式版本1.0,TensorFlow的官網也增加了TensorFlow.js的文件,這說明TensorFlow.js不再是一個試驗品。作為一名瀏覽器核心研發工程師,對TensorFlow.js自然充滿了興趣。

Javascript語言這些年來四處攻城掠地,服務端有Node.js,移動前端開發更是大熱,就連桌面應用也有JS的身影,比如最近火熱的Visual Studio Code,現在又滲透到人工智慧領域。不得不感概,當年匆忙設計出來,飽受批評的一門指令碼語言,竟然生命力這麼頑強。

閒話少說,下面就來看看在瀏覽器中訓練模型是怎樣的一種體驗。

我之前寫過一系列的《一步步提高手寫數字的識別率(1)(2)(3)》,手寫數字識別是一個非常好的入門專案,所以在這裡我就以手寫數字識別為例,說明在瀏覽器中如何訓練模型。這裡就不從最簡單的線性迴歸模型開始,而是直接選用卷積神經網路。

和python程式碼中訓練模型的步驟一樣,使用TensorFlow.js在瀏覽器中訓練模型的步驟主要有4步:

  • 載入資料。
  • 定義模型結構。
  • 訓練模型並監控其訓練時的表現。
  • 評估訓練的模型。

載入資料

有過機器學習知識的朋友,應該對MNIST資料集不陌生,這是一套28x28大小手寫數字的灰度影像,包含55000個訓練樣本,10000個測試樣本,另外還有5000個交叉驗證資料樣本。tensorflow python提供了一個封裝類,可以直接載入MNIST資料集,在TensorFlow.js中需要自己寫程式碼載入:

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
    'mnist_images.png';
const MNIST_LABELS_PATH =
    'mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
複製程式碼

程式碼中,載入一個 mnist_images.png 圖片,該圖片是所有MNIST資料集的影像拼接而來(檔案很大,大約10M),另外載入一個 mnist_labels_uint8 文字檔案,包含所有的MNIST資料集對應的標籤。

需要注意的是,這只是一種載入MNIST資料集的方法,你也可以使用一個手寫數字一張圖片的MNIST資料集,分次載入多個圖片檔案。

上述程式碼實現了一個MnistData類,它有兩個公共方法:

  • nextTrainBatch(batchSize):從訓練集中返回一組隨機影像及其標籤。
  • nextTestBatch(batchSize):從測試集中返回一批影像及其標籤

為了檢驗上述程式碼是否工作正常,可以寫一段程式碼顯示載入的資料:

async function showExamples(data) {
  // Create a container in the visor
  const surface =
    tfvis.visor().surface({ name: 'Input Data Examples', tab: 'Input Data'});

  // Get the examples
  const examples = data.nextTestBatch(20);
  const numExamples = examples.xs.shape[0];

  // Create a canvas element to render each example
  for (let i = 0; i < numExamples; i++) {
    const imageTensor = tf.tidy(() => {
      // Reshape the image to 28x28 px
      return examples.xs
        .slice([i, 0], [1, examples.xs.shape[1]])
        .reshape([28, 28, 1]);
    });

    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style = 'margin: 4px;';
    await tf.browser.toPixels(imageTensor, canvas);
    surface.drawArea.appendChild(canvas);

    imageTensor.dispose();
  }
}

async function run() {
  const data = new MnistData();
  await data.load();
  await showExamples(data);
}

document.addEventListener('DOMContentLoaded', run);
複製程式碼

瀏覽器中的手寫數字識別

定義模型結構

關於卷積神經網路,可以參閱《一步步提高手寫數字的識別率(3)》這篇文章,這裡定義的卷積網路結構為:

CONV -> MAXPOOlING -> CONV -> MAXPOOLING -> FC -> SOFTMAX

每個卷積層使用RELU啟用函式,程式碼如下:

function getModel() {
  const model = tf.sequential();

  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const IMAGE_CHANNELS = 1;

  // In the first layer of out convolutional neural network we have
  // to specify the input shape. Then we specify some paramaters for
  // the convolution operation that takes place in this layer.
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));

  // The MaxPooling layer acts as a sort of downsampling using max values
  // in a region instead of averaging.
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

  // Repeat another conv2d + maxPooling stack.
  // Note that we have more filters in the convolution.
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

  // Now we flatten the output from the 2D filters into a 1D vector to prepare
  // it for input into our last layer. This is common practice when feeding
  // higher dimensional data to a final classification output layer.
  model.add(tf.layers.flatten());

  // Our last layer is a dense layer which has 10 output units, one for each
  // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9).
  const NUM_OUTPUT_CLASSES = 10;
  model.add(tf.layers.dense({
    units: NUM_OUTPUT_CLASSES,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax'
  }));


  // Choose an optimizer, loss function and accuracy metric,
  // then compile and return the model
  const optimizer = tf.train.adam();
  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return model;
}
複製程式碼

如果有過tensorflow python程式碼編寫經驗,上面的程式碼應該很容易理解。

訓練模型並監控其訓練時的表現

在瀏覽器中訓練,也可以批量輸入影像資料,可以指定batch size,epoch輪次。

async function train(model, data) {
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = {
    name: 'Model Training', styles: { height: '1000px' }
  };
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;

  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  return model.fit(trainXs, trainYs, {
    batchSize: BATCH_SIZE,
    validationData: [testXs, testYs],
    epochs: 10,
    shuffle: true,
    callbacks: fitCallbacks
  });
}
複製程式碼

和python程式碼相比,fit多了一個callbacks引數。需要注意的是,訓練過程比較長,我們不能阻塞瀏覽器主執行緒,程式碼中大多時候需要非同步方法。而callbacks可以通知主執行緒更新,這裡借用了tfvis庫,可以視覺化訓練過程(類似於tensorboard),但這裡是在網頁上顯示。

瀏覽器中的手寫數字識別

評估訓練的模型

評估時喂入測試集,程式碼也和python版本的類似:

function doPrediction(model, data, testDataSize = 500) {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const testData = data.nextTestBatch(testDataSize);
  const testxs = testData.xs.reshape([testDataSize, IMAGE_WIDTH, IMAGE_HEIGHT, 1]);
  const labels = testData.labels.argMax([-1]);
  const preds = model.predict(testxs).argMax([-1]);

  testxs.dispose();
  return [preds, labels];
}
複製程式碼

如果我們希望更直觀的顯示每個類別的精確度以及錯誤的分類,可以藉助tfvis庫:

async function showAccuracy(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
  const container = {name: 'Accuracy', tab: 'Evaluation'};
  tfvis.show.perClassAccuracy(container, classAccuracy, classNames);

  labels.dispose();
}

async function showConfusion(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
  const container = {name: 'Confusion Matrix', tab: 'Evaluation'};
  tfvis.render.confusionMatrix(
      container, {values: confusionMatrix}, classNames);

  labels.dispose();
}
複製程式碼

評估結果如下圖所示:

瀏覽器中的手寫數字識別

關於TensorFlow.js

TensowFlow.js藉助於WebGL,可以加速訓練過程。如果瀏覽器不支援WebGL,也不會出錯,只不過會走CPU的路徑,當然速度也會慢很多。

雖然通過WebGL,也利用上了GPU,但對於大規模深度學習模型,在瀏覽器中訓練也不現實,這個時候我們也可以在server上訓練好模型,轉換為TensorFlow.js可用的模型格式,在瀏覽器中載入模型,並進行推斷,關於這個話題,請關注後續的文章。

以上示例有完整的程式碼,點選閱讀原文,跳轉到我在github上建的示例程式碼。 另外,你也可以在瀏覽器中直接訪問:ilego.club/ai/index.ht… ,直接體驗瀏覽器中的機器學習。

參考文獻:

  1. tensorflow官網
  2. TensorFlow.js — Handwritten digit recognition with CNNs

你還可以讀

  1. 一步步提高手寫數字的識別率(1)(2)(3)
  2. TensorFlow.js簡介

image

相關文章