深度學習工程模板(DL Project Template),簡化載入資料、構建網路、訓練模型和預測樣本的流程。
原始碼:https://github.com/SpikeKing/DL-Project-Template
使用方式
下載工程
git clone https://github.com/SpikeKing/DL-Project-Template
複製程式碼
建立和啟用虛擬環境
virtualenv venv
source venv/bin/activate
複製程式碼
安裝Python依賴庫
pip install -r requirements.txt
複製程式碼
開發流程
- 定義自己的資料載入類,繼承DataLoaderBase;
- 定義自己的網路結構類,繼承ModelBase;
- 定義自己的模型訓練類,繼承TrainerBase;
- 定義自己的樣本預測類,繼承InferBase;
- 定義自己的配置檔案,寫入實驗的相關引數;
執行訓練模型和預測樣本操作。
示例工程
識別MNIST庫中手寫數字,工程simple_mnist
訓練:
python main_train.py -c configs/simple_mnist_config.json
複製程式碼
預測:
python main_test.py -c configs/simple_mnist_config.json -m simple_mnist.weights.10-0.24.hdf5
複製程式碼
網路結構
TensorBoard
工程架構
框架圖
資料夾結構
├── bases
│ ├── data_loader_base.py - 資料載入基類
│ ├── infer_base.py - 預測樣本(推斷)基類
│ ├── model_base.py - 網路結構(模型)基類
│ ├── trainer_base.py - 訓練模型基類
├── configs - 配置資料夾
│ └── simple_mnist_config.json
├── data_loaders - 資料載入資料夾
│ ├── __init__.py
│ ├── simple_mnist_dl.py
├── experiments - 實驗資料資料夾
│ └── simple_mnist - 實驗名稱
│ ├── checkpoints - 儲存的模型和引數
│ │ └── simple_mnist.weights.10-0.24.hdf5
│ ├── images - 圖片
│ │ └── model.png
│ └── logs - 日誌,如TensorBoard
│ └── events.out.tfevents.1524034653.wang
├── infers - 推斷資料夾
│ ├── __init__.py
│ ├── simple_mnist_infer.py
├── main_test.py - 預測樣本入口
├── main_train.py - 訓練模型入口
├── models - 網路結構資料夾
│ ├── __init__.py
│ ├── simple_mnist_model.py
├── requirements.txt - 依賴庫
├── trainers - 訓練模型資料夾
│ ├── __init__.py
│ ├── simple_mnist_trainer.py
└── utils - 工具資料夾
├── __init__.py
├── config_utils.py - 配置工具類
├── np_utils.py - NumPy工具類
├── utils.py - 其他工具類
複製程式碼
主要元件
DataLoader
操作步驟:
- 建立自己的載入資料類,繼承DataLoaderBase基類;
- 覆寫
get_train_data()
和get_test_data()
,返回訓練和測試資料;
Model
操作步驟:
- 建立自己的網路結構類,繼承ModelBase基類;
- 覆寫
build_model()
,建立網路結構; - 在構造器中,呼叫
build_model()
;
注意:plot_model()
支援繪製網路結構;
Trainer
操作步驟:
- 建立自己的訓練類,繼承TrainerBase基類;
- 引數:網路結構model、訓練資料data;
- 覆寫
train()
,fit資料,訓練網路結構;
注意:支援在訓練中呼叫callbacks,額外新增模型儲存、TensorBoard、FPR度量等。
Infer
操作步驟:
- 建立自己的預測類,繼承InferBase基類;
- 覆寫
load_model()
,提供模型載入功能; - 覆寫
predict()
,提供樣本預測功能;
Config
定義在模型訓練過程中所需的引數,JSON格式,支援:學習率、Epoch、Batch等引數。
Main
訓練:
- 建立配置檔案config;
- 建立資料載入類dl;
- 建立網路結構類model;
- 建立訓練類trainer,引數是訓練和測試資料、模型;
- 執行訓練類trainer的train();
預測:
- 建立配置檔案config;
- 處理預測樣本test;
- 建立預測類infer;
- 執行預測類infer的predict();
感謝
參考Tensorflow-Project-Template工程
By C. L. Wang @ 美圖雲事業部