深度學習工程模板

技術小能手發表於2018-10-25

深度學習工程模板(DL Project Template),簡化載入資料、構建網路、訓練模型和預測樣本的流程。

原始碼:https://github.com/SpikeKing/DL-Project-Template

使用方式

下載工程

git clone https://github.com/SpikeKing/DL-Project-Template
複製程式碼

建立和啟用虛擬環境

virtualenv venv
source venv/bin/activate
複製程式碼

安裝Python依賴庫

pip install -r requirements.txt
複製程式碼

開發流程

  1. 定義自己的資料載入類,繼承DataLoaderBase;
  2. 定義自己的網路結構類,繼承ModelBase;
  3. 定義自己的模型訓練類,繼承TrainerBase;
  4. 定義自己的樣本預測類,繼承InferBase;
  5. 定義自己的配置檔案,寫入實驗的相關引數;

執行訓練模型和預測樣本操作。

示例工程

識別MNIST庫中手寫數字,工程simple_mnist

訓練:

python main_train.py -c configs/simple_mnist_config.json
複製程式碼

預測:

python main_test.py -c configs/simple_mnist_config.json -m simple_mnist.weights.10-0.24.hdf5
複製程式碼

網路結構

網路結構

TensorBoard

TensorBoard

工程架構

框架圖

架構

資料夾結構

├── bases
│   ├── data_loader_base.py             - 資料載入基類
│   ├── infer_base.py                   - 預測樣本(推斷)基類
│   ├── model_base.py                   - 網路結構(模型)基類
│   ├── trainer_base.py                 - 訓練模型基類
├── configs                             - 配置資料夾
│   └── simple_mnist_config.json
├── data_loaders                        - 資料載入資料夾
│   ├── __init__.py
│   ├── simple_mnist_dl.py
├── experiments                         - 實驗資料資料夾
│   └── simple_mnist                    - 實驗名稱
│       ├── checkpoints                 - 儲存的模型和引數
│       │   └── simple_mnist.weights.10-0.24.hdf5
│       ├── images                      - 圖片
│       │   └── model.png
│       └── logs                        - 日誌,如TensorBoard
│           └── events.out.tfevents.1524034653.wang
├── infers                              - 推斷資料夾
│   ├── __init__.py
│   ├── simple_mnist_infer.py
├── main_test.py                        - 預測樣本入口
├── main_train.py                       - 訓練模型入口
├── models                              - 網路結構資料夾
│   ├── __init__.py
│   ├── simple_mnist_model.py
├── requirements.txt                    - 依賴庫
├── trainers                            - 訓練模型資料夾
│   ├── __init__.py
│   ├── simple_mnist_trainer.py
└── utils                               - 工具資料夾
    ├── __init__.py
    ├── config_utils.py                 - 配置工具類
    ├── np_utils.py                     - NumPy工具類
    ├── utils.py                        - 其他工具類
複製程式碼

主要元件

DataLoader

操作步驟:

  1. 建立自己的載入資料類,繼承DataLoaderBase基類;
  2. 覆寫get_train_data()get_test_data(),返回訓練和測試資料;

Model

操作步驟:

  1. 建立自己的網路結構類,繼承ModelBase基類;
  2. 覆寫build_model(),建立網路結構;
  3. 在構造器中,呼叫build_model()

注意:plot_model()支援繪製網路結構;

Trainer

操作步驟:

  1. 建立自己的訓練類,繼承TrainerBase基類;
  2. 引數:網路結構model、訓練資料data;
  3. 覆寫train(),fit資料,訓練網路結構;

注意:支援在訓練中呼叫callbacks,額外新增模型儲存、TensorBoard、FPR度量等。

Infer

操作步驟:

  1. 建立自己的預測類,繼承InferBase基類;
  2. 覆寫load_model(),提供模型載入功能;
  3. 覆寫predict(),提供樣本預測功能;

Config

定義在模型訓練過程中所需的引數,JSON格式,支援:學習率、Epoch、Batch等引數。

Main

訓練:

  1. 建立配置檔案config;
  2. 建立資料載入類dl;
  3. 建立網路結構類model;
  4. 建立訓練類trainer,引數是訓練和測試資料、模型;
  5. 執行訓練類trainer的train();

預測:

  1. 建立配置檔案config;
  2. 處理預測樣本test;
  3. 建立預測類infer;
  4. 執行預測類infer的predict();

感謝

參考Tensorflow-Project-Template工程

By C. L. Wang @ 美圖雲事業部

相關文章