在上一篇文章《瀏覽器中的手寫數字識別》中,講到在瀏覽器中訓練出一個卷積神經網路模型,用來識別手寫數字。值得注意的是,這個訓練過程是在瀏覽器中完成的,使用的是客戶端的資源。
雖然TensorFlow.js的願景是機器學習無處不在,即使是在手機、嵌入式裝置上,只要執行有瀏覽器,都可以訓練人工智慧模型,但是考慮到手機、嵌入式裝置有限的計算能力(雖然手機效能不斷飛躍),複雜的人工智慧模型還是交給更為強大的伺服器來訓練比較合適。況且目前主流的機器學習採用的是python語言,要讓廣大機器學習工程師從python轉向js,估計大家也不會答應。
如果是這樣的話,那TensorFlow.js推出還有何意義呢?
這個問題其實和TensorFlow Lite類似,我們可以在伺服器端訓練,在手機上使用訓練出的模型進行推導,通常推導並不需要那麼強大的計算能力。
在本文,我們將探索如何在TensorFlow.js中載入預訓練的機器學習模型,完成圖片分類任務。
在TensorFlow官網,訪問 www.tensorflow.org/js/models/ 這個網址,可以看到裡面有實時姿態預測模型、目標檢測模型、語音識別模型、分類模型等等:
這裡我們選擇MobileNets模型。MobileNets是一種小型、低延遲、低耗能模型,滿足各種資源受限的使用場景,可用於分類、檢測、嵌入和分割,功能上類似於其他流行的大型模型(如Inception)。 MobileNets在延遲、大小和準確性之間取得了平衡。
有兩種使用MobileNets模型的方案:
- 直接呼叫MobileNets模型的JS封裝庫
- 自己編寫程式碼載入json格式的MobileNets模型
直接呼叫MobileNets模型的JS封裝庫
JS封裝庫直接將MobileNets模型封裝為JS物件,我們就像呼叫普通的JS物件那樣,呼叫物件方法,完成模型載入、推斷。
比如訪問 github.com/tensorflow/… ,我們可以看到該mobilenet物件提供兩個主要的API:
mobilenet.load(
version?: 1,
alpha?: 0.25 | .50 | .75 | 1.0
)
複製程式碼
引數:
- 版本:MobileNet版本號。1表示MobileNet V1,2表示使用MobileNet V2。預設值為1。
- alpha:較小的alpha會降低精度,但會提高效能。預設值為1.0。
model.classify(
img: tf.Tensor3D | ImageData | HTMLImageElement |
HTMLCanvasElement | HTMLVideoElement,
topk?: number
)
複製程式碼
引數:
- img:進行分類的Tensor或image元素。
- topk:要返回多少個Top概率。預設值為3。
藉助於封裝的JS庫,在瀏覽器中使用MobileNets就相當簡單了:
<html>
<head>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.1"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"> </script>
</head>
<body>
<img id="img" src="cat.jpg"></img>
<script>
const img = document.getElementById('img');
// Load the model.
mobilenet.load().then(model => {
// Classify the image.
model.classify(img).then(predictions => {
console.log('Predictions: ');
console.log(predictions);
});
});
</script>
</body>
</html>
複製程式碼
注意: 這裡的js程式碼會去google storage 載入MobileNets的JSON格式模型,而由於一些不能說的原因,國內無法訪問到,請自行翻牆。
這個示例寫的比較簡單,從瀏覽器控制檯輸出log,顯示結果,在chrome瀏覽器中可以開啟開發者工具檢視:
載入json格式的MobileNets模型
使用封裝好的JS物件確實方便,但使用自己訓練的模型時,並沒有人為我們提供封裝物件。這個時候我們就要考慮自行載入模型,並進行推斷。在JS世界,JSON是使用得非常普遍的資料交換格式。TensorFlow.js也採用JSON作為模型格式,也提供了工具進行轉換。
本來這裡想詳細寫一下如何載入json格式的MobileNets模型,但由於MobileNets的JS模型託管在Google伺服器上,國內無法訪問,所以這裡先跳過這一步。在下一篇文章中我將說明如何從現有的TensorFlow模型轉換為TensorFlow.js模型,並載入之,敬請關注!
以上示例有完整的程式碼,點選閱讀原文,跳轉到我在github上建的示例程式碼。 另外,你也可以在瀏覽器中直接訪問:ilego.club/ai/index.ht… ,直接體驗瀏覽器中的機器學習。
參考文獻:
你還可以讀
- 一步步提高手寫數字的識別率(1)(2)(3)
- TensorFlow.js簡介
- 瀏覽器中的手寫數字識別