- 原文地址:How to Use TensorFlow Mobile in Android Apps
- 原文作者:Ashraff Hathibelagal
- 譯文出自:掘金翻譯計劃
- 本文永久連結:github.com/xitu/gold-m…
- 譯者:luochen
- 校對者:ALVINYEH LeeSniper
TensorFlow 是當今最流行的機器學習框架之一,您利用它可以輕鬆建立和訓練深度模型 —— 通常也稱為深度前饋神經網路,這些模型可以解決各種複雜問題,如影像分類、目標檢測和自然語言理解。TensorFlow Mobile 是一個旨在幫助您在移動應用中利用這些模型的庫。
在本教程中,我將向您展示如何在 Android Studio 專案中使用 TensorFlow Mobile。
前期準備
為了能夠跟上教程,您需要做的是:
- Android Studio 3.0 或更高版本
- TensorFlow 1.5.0 或更高版本
- 一臺能夠執行 API level 21 或更高的安卓裝置
- 以及對 TensorFlow 框架的基本瞭解
1、建立模型
在我們開始使用 TensorFlow Mobile 之前,我們需要一個已經訓練好的 TensorFlow 模型。我們現在建立一個。
我們的模型將非常基礎,類似於異或門,接受兩個輸入,它們可以是零或一,然後有一個輸出。如果兩個輸入相同,則輸出為零。此外,因為它將是一個深度模型,它將有兩個隱藏層,一個有四個神經元,另一個有三個神經元。您可以自由改變隱藏層的數量以及它們包含的神經元的數量。
為了保持本教程的簡潔,我們將使用 TFLearn,這是一個很受歡迎的 TensorFlow 封裝框架,它提供更加直接而簡潔的 API,而不是直接使用低階別的 TensorFlow API。如果您還沒安裝它,請使用以下命令將其安裝在 TensorFlow 虛擬環境中:
pip install tflearn
複製程式碼
要開始建立模型,最好在空目錄中先新建一個名為 create_model.py 的 Python 指令碼,然後使用您最喜歡的文字編輯器開啟它。
在檔案裡,我們需要做的第一件事是匯入 TFLearn API。
import tflearn
複製程式碼
接下來,我們必須建立訓練資料。對於我們的簡單模型,只有四種可能的輸入和輸出,類似於異或門真值表的內容。
X = [
[0, 0],
[0, 1],
[1, 0],
[1, 1]
]
Y = [
[0], # Desired output for inputs 0, 0
[1], # Desired output for inputs 0, 1
[1], # Desired output for inputs 1, 0
[0] # Desired output for inputs 1, 1
]
複製程式碼
為隱藏層中的所有神經元分配初始權重時,最好的做法通常是使用從均勻分佈中產生的隨機數。可以使用 uniform()
方法生成這些值。
weights = tflearn.initializations.uniform(minval = -1, maxval = 1)
複製程式碼
此時,我們可以開始構建神經網路層。要建立輸入層,我們必須使用 input_data()
方法,它允許我們指定網路可以接受的輸入數量。一旦輸入層準備就緒,我們可以多次呼叫 fully_connected()
方法來向網路新增更多層。
# 輸入層
net = tflearn.input_data(
shape = [None, 2],
name = 'my_input'
)
# 隱藏層
net = tflearn.fully_connected(net, 4,
activation = 'sigmoid',
weights_init = weights
)
net = tflearn.fully_connected(net, 3,
activation = 'sigmoid',
weights_init = weights
)
# 輸出層
net = tflearn.fully_connected(net, 1,
activation = 'sigmoid',
weights_init = weights,
name = 'my_output'
)
複製程式碼
注意,在上面的程式碼中,我們賦予了輸入層和輸出層有意義的名稱。這麼做很重要,因為我們在使用安卓應用中的網路時需要它們。還要注意隱藏層和輸出層使用了 sigmoid
啟用函式。您可以試試其他啟用函式,例如 softmax
、tanh
和 relu
。
作為我們網路的最後一層,我們必須使用 regression()
函式建立一個迴歸層,該函式需要一些超引數作為其引數,例如網路的學習率以及它應該使用的優化器和損失函式。以下程式碼向您展示瞭如何使用隨機梯度下降(簡稱 SGD)作為優化器函式,均方誤差作為損失函式:
net = tflearn.regression(net,
learning_rate = 2,
optimizer = 'sgd',
loss = 'mean_square'
)
複製程式碼
接下來,為了讓 TFLearn 框架知道我們的網路模型實際上是一個深度神經網路模型,我們須要呼叫 DNN()
函式。
model = tflearn.DNN(net)
複製程式碼
模型現在已經準備好了。我們現在要做的就是使用我們之前建立的訓練資料進行訓練。因此,呼叫模型的 fit()
方法,並指定訓練資料與訓練週期。由於訓練資料非常小,我們的模型將需要數千次迭代才能達到合理的精度。
model.fit(X, Y, 5000)
複製程式碼
一旦訓練完成,我們可以呼叫模型的 predict()
方法來檢查它是否生成期望的輸出。以下程式碼展示瞭如何檢查所有有效輸入的輸出:
print("1 XOR 0 = %f" % model.predict([[1,0]]).item(0))
print("1 XOR 1 = %f" % model.predict([[1,1]]).item(0))
print("0 XOR 1 = %f" % model.predict([[0,1]]).item(0))
print("0 XOR 0 = %f" % model.predict([[0,0]]).item(0))
複製程式碼
如果現在執行 Python 指令碼,您應該看到如下所示的輸出:
請注意,輸出不會完全是 0 或 1。而是接近 0 或 1 的浮點數。因此,在使用輸出時,可能需要使用 Python 的 round()
函式。
除非我們在訓練後明確儲存模型,否則只要程式結束,我們就會失去模型。幸運的是,對於 TFLearn,只需呼叫 save()
方法即可儲存模型。但是,為了能夠在 TensorFlow Mobile 中使用儲存的模型,在儲存之前,我們必須確保移除所有訓練相關的操作。這些操作都在 tf.GraphKeys.TRAIN_OPS 集合中。以下程式碼展示了怎麼去移除相關操作:
# 移除訓練相關的操作
with net.graph.as_default():
del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
# 儲存模型
model.save('xor.tflearn')
複製程式碼
如果您再次執行該指令碼,您會發現它會生成檢查點檔案、後設資料檔案、索引檔案和資料檔案,所有這些檔案一起使用時可以快速重建我們訓練好的模型。
2、固化模型
除了儲存模型外,我們還必須先固化模型,然後才能將其與 TensorFlow Mobile 配合使用。正如您可能已經猜到的那樣,固化模型的過程涉及將其所有變數轉換為常量。此外,固化模型必須是符合 Google Protocol Buffers 序列化格式的單個二進位制檔案。
新建一個名為 freeze_model.py 的 Python 指令碼,並使用文字編輯器開啟它。我們將在這個檔案中編寫固化的模型程式碼來。
由於 TFLearn 沒有任何固化模型的功能,我們現在必須直接使用 TensorFlow API。通過將以下行新增到檔案來匯入它們:
import tensorflow as tf
複製程式碼
整個指令碼里面,我們將使用單個 TensorFlow 會話。我們使用 Session
類的建構函式建立會話。
with tf.Session() as session:
# 程式碼的其他部分在這
複製程式碼
此時,我們必須通過呼叫 import_meta_graph()
函式並將模型的後設資料檔案的名稱傳遞給它來建立 Saver
物件,除了返回 Saver
物件外,import_meta_graph()
函式還會自動將模型的圖定義新增到會話的圖定義中。
一旦建立了儲存器(saver),我們可以通過呼叫 restore()
方法來初始化圖定義中存在的所有變數,該方法需要包含模型最新檢查點檔案的目錄路徑。
my_saver = tf.train.import_meta_graph('xor.tflearn.meta')
my_saver.restore(session, tf.train.latest_checkpoint('.'))
複製程式碼
此時,我們可以呼叫 convert_variables_to_constants()
函式來建立一個固化的圖定義,其中模型的所有變數都替換成常量。作為其輸入,函式需要當前會話、當前會話的圖定義以及包含模型輸出層名稱的列表。
frozen_graph = tf.graph_util.convert_variables_to_constants(
session,
session.graph_def,
['my_output/Sigmoid']
)
複製程式碼
呼叫固化圖定義的 SerializeToString()
方法為我們提供了模型的二進位制 protobuf 表示。通過使用 Python 基本的檔案 I/O,我建議您把它儲存為一個名為 frozen_model.pb 的檔案。
with open('frozen_model.pb', 'wb') as f:
f.write(frozen_graph.SerializeToString())
複製程式碼
現在可以執行指令碼來生成固化模型。
我們現在擁有開始使用 TensorFlow Mobile 所需的一切。
3、Android Studio 專案設定
TensorFlow Mobile 庫可在 JCenter 上使用,所以我們可以直接將它新增為 app
模組 build.gradle 檔案中的 implementation
依賴項。
implementation 'org.tensorflow:tensorflow-android:1.7.0'
複製程式碼
要把固化的模型新增到專案中,請將 frozen_model.pb 檔案放置到專案的 assets 資料夾中。
4、初始化 TensorFlow 介面
TensorFlow Mobile 提供了一個簡單的介面,我們可以使用它與我們的固化模型進行互動。要建立介面,請使用 TensorFlowInferenceInterface
類的建構函式,該類需要一個 AssetManager
例項和固化模型的檔名。
thread {
val tfInterface = TensorFlowInferenceInterface(assets,
"frozen_model.pb")
// More code here
}
複製程式碼
在上面的程式碼中,您可以看到我們正在產生一個新的執行緒。這是為了確保應用的 UI 保持響應,雖然不必要,但建議這樣做。
為了保證 TensorFlow Mobile 能夠正確讀取我們模型的檔案,現在讓我們嘗試列印模型圖中所有操作的名稱。為了得到對圖的引用,我們可以使用介面的 graph()
方法,並獲取所有操作,即圖的 operations()
方法。以下程式碼告訴您該怎麼做:
val graph = tfInterface.graph()
graph.operations().forEach {
println(it.name())
}
複製程式碼
如果現在執行該應用,則應該能夠看到在 Android Studio 的 Logcat 視窗中列印的十幾個操作名稱。如果固化模型時沒有出錯,我們可以在這些名稱中找到輸入和輸出層的名稱:my_input/X 和 my_output/Sigmoid。
5、使用模型
為了用模型進行預測,我們將資料輸入到輸入層,在輸出層得到資料。將資料輸入到輸入層需要使用介面的 feed()
方法,該方法需要輸入層的名稱、含有輸入資料的陣列以及陣列的維數。以下程式碼展示如何將數字 0
和 1
輸入到輸入層:
tfInterface.feed("my_input/X",
floatArrayOf(0f, 1f), 1, 2)
複製程式碼
資料載入到輸入層後,我們必須使用 run()
方法進行推斷操作,該方法需要輸出層的名稱。一旦操作完成,輸出層將包含模型的預測。為了將預測結果載入到 Kotlin 陣列中,我們可以使用 fetch()
方法。以下程式碼顯示瞭如何執行此操作:
tfInterface.run(arrayOf("my_output/Sigmoid"))
val output = floatArrayOf(-1f)
tfInterface.fetch("my_output/Sigmoid", output)
複製程式碼
您現在可以執行該應用來檢視模型的預測是否正確。
可以更改輸入到輸入層的數字,以確認模型的預測始終正確。
總結
您現在知道如何建立一個簡單的 TensorFlow 模型以及在安卓應用上通過 TensorFlow Mobile 去使用該模型。不過不必拘泥於自己的模型,用您今天學到的東西,使用更大的模型對您來說應該沒有任何問題。例如 MobileNet 以及 Inception,這些都可以在 TensorFlow 的 模型園 裡找到。但是請注意,這些模型會使 APK 更大,從而給使用低端裝置的使用者造成問題。
要了解有關 TensorFlow Mobile 的更多資訊,請參閱 官方文件.
掘金翻譯計劃 是一個翻譯優質網際網路技術文章的社群,文章來源為 掘金 上的英文分享文章。內容覆蓋 Android、iOS、前端、後端、區塊鏈、產品、設計、人工智慧等領域,想要檢視更多優質譯文請持續關注 掘金翻譯計劃、官方微博、知乎專欄。