TensorFlow Hub--用一行程式碼完成遷移學習

wam_weicher發表於2018-07-04

TensorFlow Hub簡介

TensorFlow Hub是一個用於促進機器學習模型中可複用部分再次進行探索與釋出的庫,主要將預訓練過的TensorFlow模型片段再次利用到新的任務上。(可以理解為做遷移學習)

要使用TensorFlow Hub需要你本地安裝的TensorFlow的版本在1.7以上(TensorFlow的安裝配置過程本文不做介紹,若有需要,可以參考此文

通過以下命令即可安裝TensorFlow Hub

pip install tensorflow-hub

下載好的TensorFlow Hub版本資訊如下圖所示

版本資訊.png


TensorFlow Hub使用

方便起見,我們使用TensorFlow Hub官方提供的花卉圖片集來作為我們的資料集,網路條件允許的讀者,可以使用如下命令下載圖片集

cd ~

curl -LO http://download.tensorflow.org/example_images/flower_photos.tgz

tar xzf flower_photos.tgz

下載不下來也沒事,本人貼心的準備了百度雲連結 連結: https://pan.baidu.com/s/1NVl8uUU7iVktxE0g7Oa5jw 密碼: ivt8

下載下來之後解壓即可。

解壓之後我們可以看到,flower_photos資料夾下包含了如下幾個子資料夾,每一個子資料夾的名字都代表了其中圖片的標籤,如dandelion(蒲公英)資料夾下的所有圖片皆為蒲公英的圖片。

資料集目錄.png

有了圖片集了,我們還需要用於遷移學習的訓練程式碼。 同樣,也可以通過如下命令下載用於遷移學習的訓練程式碼

mkdir ~/example_code

cd ~/example_code

curl -LO https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py

讀者也可自行前往TensorFlow Hub的官方GitHub倉庫下載。

我們需要使用的是hub-master\examples\image_retraining路徑下的retrain.py檔案,這個py檔案是官方為我們準備好的用作圖片分類遷移學習的樣板程式碼。

關鍵的地方到了

如何用一行程式碼實現複雜的遷移學習

因為本人的tensorflow-gpu庫裝在單獨的環境下,所以我需要啟用Anaconda Prompt,然後啟用我所要用的指定環境

activate my_special_env

(若用系統Path路徑下的python環境,則忽略上一步)

再將路徑cd到你本地retrain.py檔案的路徑下(具體的路徑根據自己的實際情況更改,下圖為本人的路徑地址)

路徑地址.png

最後,最關鍵的一行程式碼來了:

python retrain.py --image_dir ~/flower_photos

(~/flower_photos為下載下來解壓好的圖片集的路徑,請根據實際情況修改)

然後,遷移學習的訓練就跑起來了~!


TensorFlow Hub小貼士

1.為什麼訓練一直卡在downloading位置?

因為TensorFlow Hub是通過url的形式獲取的網上釋出的模型,如果你有幸成功跑完整個訓練,你可以在C:\Users\你的使用者名稱\AppData\Local\Temp\tfhub_modules中看到一個資料夾和一個txt文字 其中,txt文字的內容如下:

Module: https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1

Download Time: 2018-07-02 18:21:36.051380

Downloader Hostname: WSJ-LAPTOP (PID:18172)

從中不難看出,TensorFlow Hub是將從網上下載的inception_v3模型用作我們剛才遷移學習的預訓練模型(通過閱讀retrain.py的原始碼你也能發現這一點),所以如果網路狀態不好或者翻牆不順的話,那就自然是下載不下來的......

2.訓練完後的的模型檔案儲存在哪兒?

預設是儲存在tmp/資料夾下的,因為本人retrain.py檔案所在碟符為E盤,所以所有訓練生成的瓶頸檔案,ckpt檔案,pb檔案,label檔案都在E:\tmp路徑下。

3.訓練完的模型中,輸入和輸出的tensor分別是什麼?

根據官方文件,輸入的tensor是"Placeholder",輸出的tensor是"final_result"。讀者可以使用官方的影象分類預測程式碼來測試已訓練好的模型。此文不展開描述測試的具體步驟,讀者可自行查閱

4.除了預設的模型,我們還可以用哪些預訓練模型?

retrain.py中--tfhub_moduled的預設值即為inception_v3模型的url,如需替換模型,可以參考官方文件,其中列舉了所有可用到的已釋出的官方預訓練模型。

tfhub_module.png


若想了解更多的資料,如超引數的設定等詳細配置資訊,請大家自行查閱官方GitHub倉庫 TensorFlow Hub最新中文網站:https://tensorflow.google.cn/hub/



若您覺得本文章對您有用,請您為我點上一顆小心心以表支援。感謝!

相關文章