機器學習,Hello World from Javascript!

發表於2017-12-11

機器學習,Hello World from Javascript!

導語

JavaScript 適合做機器學習嗎?這是一個問號。但每一位開發者都應該瞭解機器學習解決問題的思維和方法,並思考:它將會給我們的工作帶來什麼?同樣,演算法能力可能會是下一階段工程師的標配。

本文旨在通過講解識別手寫字的處理過程,帶讀者瞭解機器學習解決問題的一般過程。本文適合以下背景的讀者閱讀:

  • 你不需要具備 Python、C++ 的程式設計能力:全文使用 JavaScript 作為程式語言,且不依賴任何第三方庫實現機器學習演算法。
  • 你不需要具備演算法能力和高數的背景,本文機器學習演算法的實現不過 20 行程式碼。

作者學識有限,文章中難免會有疏漏,歡迎指正。

機器學習中的 Hello World

就像我們學習程式語言一樣,我們的第一個嘗試就是在終端命令列中輸出的 “Hello World”。機器學習中的 “Hello World” 便是識別手寫字資料集。手寫字是形如下面的影像:

手寫字圖

我們可以編寫一個網頁程式,提供手寫板的功能來捕獲使用者的輸入,並返回我們識別的數字:使用者在手寫板內寫下 0 到 9 中的任意一個數字,另一側則顯示我們識別的結果。正如 Keras.js 提供的示例那樣[1]:

Keras.js - MNIST 示例截圖

如何編寫出這樣的手寫識別程式來獲取使用者的手寫輸入不是我們這篇文章的重點。我們的重點是,當我們的程式得到這樣一張影像的資料後,如何識別出這組資料表示的數字?

資料的表示和收集

人類能夠從影像中獲得資訊,但程式如何知道 A 圖是表示 1,B 圖是表示 2 ?因此我們需要確定資料的表示方式:用怎樣的一種方式來在程式中表達一張白底黑字的影像它的畫素點分佈及點的黑白度?

觀察 Keras.js 的示例,你會發現手寫板的面積是 240px * 240px。即手寫板內有 57,600 個畫素點。我們可以把它們平鋪開來,並且用 0 到 1 的數值表示每個點的黑白度,其中越接近 1 則表示該畫素點越黑,那麼就可以用一個數值矩陣來表示手寫字:

手寫字的資料表示

手寫板程式獲得使用者的輸入並生成影像後,識別程式將影像轉換成我們需要的資料格式。影像識別是另一個廣泛的課題,在這裡不再展開。我們會直接使用 MNIST 資料集[2],它的資料表示方式正如上面所描述的那樣,只是 MNIST 資料集中每一張圖片是包含 28 * 28 個畫素點的。

確定了資料的表示方式,接下來我們還需要對每個資料的實際含義進行標識。

回想一下我們自己是如何認識這些數字的?即我們是怎樣認定影像中的 1 形狀表示的就是數字 1?————事物的認識。認知是由他人教育的。

同樣,在機器學習中,我們也需要“教育”機器:A 這樣的畫素點排序就是 1, B 這樣的畫素點排序是 2。這就是訓練資料

為了收集訓練資料,我們可以隨機找人在手寫識別程式中畫數字,然後標識它的結果,最終以任何的形式(文字、表格…)儲存。以手寫字為例,我們可以用文字的方式儲存,格式可以是這樣:

其中每一行代表一個訓練資料,使用 “|” 分割資料的表示和它對應的數字。

在本文中我們將直接使用 MNIST 資料集,因此如何收集資料在這裡不再展開。在機器學習中經常會使用公開的資料集來進行訓練和測試

通過確定資料的表示和收集,我們可以瞭解到的是:

  • 資料是一切機器學習的基礎;
  • 訓練資料的好壞將會影響到我們機器學習演算法預測的準確率:
    • 想象一下如果某些資料我們標識錯誤,把 1 標識成 2;
    • 想象一下如果訓練資料中有大量的重複值,或某個數字的資料量特別大而另外一些數字的資料量很小。

準備資料

我們收集到的資料可能會以任何的一種形式儲存,例如文字、表格、二進位制檔案等等。MNIST 資料集是使用二進位制儲存的,因此在程式中我們需要將它轉換為 Javascript 比較容易操作的的資料格式,例如陣列。

本文中我們將使用一個 NPM 包 mnist[3] 提供的,已經轉換好的資料,它的格式如下:

其中 input 是影像的資料表示,output 是影像實際代表的數字。output 使用了一種叫做 One-Hot 的編碼方式,它一共有 10 個項,為 1 的項就是它表示的數字(第一項為 1 則代表是 0,第二項為 1 則代表是 2 ,以此類推)。

選擇一種演算法

通過上面的資料準備,我們已經把一個現實中的問題轉化成了一個數學問題:給定 728 個 0 到 1 之間數值的特徵,應該將它分類到 0 ~ 9 哪個數字中?

這就是機器學習中的主要任務——分類。有很多的機器學習演算法可以用來解決分類問題,文字將使用 k-近鄰演算法(k-NN)[4]來解決這個問題,因為它非常有效且容易理解。

k-近鄰演算法概述

在一個 10 * 10 的二維平面內畫一條線把它分成 2 個區域(A/B)。假設我們不知道線是如何畫的,但現已知有 4 個點,a 點座標是 (1, 1) 屬於區域 A,b 點座標是 (2, 2) 屬於區域 A,c 點座標是 (9, 9) 屬於區域 B,d 點座標是 (8, 8) 屬於區域 B。這時候再給定一個 e 點座標是 (8.5, 8.5) ,請問它最有可能在哪個區域內?

Index Point 1 Point 2 Area
a 1 1 A
b 2 2 A
c 9 9 B
d 8 8 B
e 8.5 8.5 ?

二維平面圖示例

絕大多數人都會說“可能是 B”。我們是如何得出這個答案的?——因為它和 c, d “看起來更接近一些,更有可能在同一個區域”。同樣的推論可以延伸至三維、四維甚至更多緯度的資料中。MNIST 的資料表示就是 728 個特徵的多緯資料,k-近鄰演算法同樣適用。

存在一個訓練樣本集,並且樣本集中每個資料都存在標籤,即我們知道樣本集中每一資料與所屬分類的對應關係。輸入沒有標籤的新資料後,將新資料的每個特徵與樣本集中資料對應的特徵進行比較,然後演算法提取樣本特徵最近鄰的分類標籤。一般來說,我們只選擇樣本資料集中前 k 個最相似的資料,這就是 k-近鄰演算法的 k 的出處。
——《機器學習實戰》k 近鄰演算法

兩個向量之間的距離可以通過歐幾里得距離公式求得:

歐幾里得距離公式

於是實現一個 k-NN 演算法就很簡單了:

測試演算法

為了驗證我們的演算法的效果,我們需要對其進行測試。這就需要引入測試資料。在機器學習中通常會將收集到的資料通過一定的方法劃分為訓練資料和測試資料然後用於訓練和測試。如何劃分資料在這裡不展開,在本示例中,我們按照 80:20 的比例來劃分訓練和測試資料,互斥性和隨機性由 MNIST 庫進行保證。

拿到訓練和測試資料後我們就可以對上一步編寫的演算法進行測試了,我們用錯誤率來評估演算法的可靠性,錯誤率越低則越可靠:

如無意外,你的終端將會輸出這樣的結果:

kNN執行結果

最終錯誤率的值大約是 5%。這個結果好嗎?並不好。我們可以通過改變 k 的值、改變訓練樣本的數目影響 k-近鄰演算法的錯誤率,讀者可以嘗試改變這些變數值觀察錯誤率的變化。實際上,只要將 k-近鄰演算法稍加改良,我們就能夠把錯誤率降到 1% 以下!

MNIST資料集中kNN演算法的效率

表格中列出了一些 k-近鄰演算法對 MNIST 資料集進行測試的錯誤率,圖片來自 http://yann.lecun.com/exdb/mnist/

我們也應該注意到的是,我們的演算法在 8000 條訓練資料集和 2000 條測試資料集上進行測試,執行了 325 秒!這是一個很差的結果。在實際生產環境中,我們不僅應該關注準確率也應該關注演算法的執行效率。

使用演算法

只要測試的演算法效果符合預期,我們就可以將演算法部署到生產環境進行使用了。我們可以將演算法和手寫識別程式結合起來,完成一整套獲取輸入 -> 演算法預測 -> 輸出結果的流程:首先手寫識別程式將使用者輸入的影像轉換為我們期望的資料格式,然後執行我們的演算法獲取預測的分類。程式碼可能是這樣:

很遺憾,在執行演算法時我們還是看到了 trainingImages 的存在。這意味著每次預測我們的機器都需要給訓練資料準備格外的儲存空間。假設訓練資料很大(這很常見),則會給我們的生產環境機器造成巨大的記憶體壓力。

每次呼叫演算法還需要傳入訓練資料的方式即浪費儲存空間也不優雅,它只能作為我們的示例進行使用。

進一步思考

本文我們使用 Javascript 實現了一個非常簡單的機器學習演算法,並用其來測試 MNIST 資料集,完整程式碼實現在這個倉庫中。這只是一個簡單的示例,但從中我們瞭解到了機器學習的基本概念和解決問題的一般過程。進一步思考,上面的流程中每一步都可能被優化:

  • 對於手寫字,還有沒有其他的資料表示方式?例如我們非要用 0 到 1 的數值來表示點的黑白度嗎?
  • 訓練資料集是越大越好嗎?例如我們將手寫字所有的特徵排列組合 (28^28) 個資料量作為訓練資料集;
  • 如何調整演算法引數以獲得最佳的收益(準確率和效率)?

參考資料

  1. Keras.js – MNIST 示例
  2. MNIST 資料集
  3. mnist – NPM
  4. k-NN

題圖:https://unsplash.com/photos/OFpzBycm3u0 By @jens johnsson

相關文章