快速開啟你的第一個專案:TensorFlow專案架構模板

黃小天發表於2018-02-06

專案連結:https://github.com/Mrgemy95/Tensorflow-Project-Template


TensorFlow 專案模板


簡潔而精密的結構對於深度學習專案來說是必不可少的,在經過多次練習和 TensorFlow 專案開發之後,本文作者提出了一個結合簡便性、最佳化檔案結構和良好 OOP 設計的 TensorFlow 專案模板。該模板可以幫助你快速啟動自己的 TensorFlow 專案,直接從實現自己的核心思想開始。

這個簡單的模板可以幫助你直接從構建模型、訓練等任務開始工作。

目錄


  • 概述

  • 詳述

  • 專案架構

  • 資料夾結構

  • 主要元件

  • 模型

  • 訓練器

  • 資料載入器

  • 記錄器

  • 配置

  • Main

  • 未來工作

概述


簡言之,本文介紹的是這一模板的使用方法,例如,如果你希望實現 VGG 模型,那麼你應該:

在模型資料夾中建立一個名為 VGG 的類,由它繼承「base_model」類

  1.   class VGGModel(BaseModel):

  2.        def __init__(self, config):

  3.            super(VGGModel, self).__init__(config)

  4.            #call the build_model and init_saver functions.

  5.            self.build_model()

  6.            self.init_saver()

覆寫這兩個函式 "build_model",在其中執行你的 VGG 模型;以及定義 TensorFlow 儲存的「init_saver」,隨後在 initalizer 中呼叫它們。

  1.    def build_model(self):

  2.        # here you build the tensorflow graph of any model you want and also define the loss.

  3.        pass

  4.     def init_saver(self):

  5.        #here you initalize the tensorflow saver that will be used in saving the checkpoints.

  6.        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)

在 trainers 資料夾中建立 VGG 訓練器,繼承「base_train」類。

  1.        class VGGTrainer(BaseTrain):

  2.        def __init__(self, sess, model, data, config, logger):

  3.            super(VGGTrainer, self).__init__(sess, model, data, config, logger)

覆寫這兩個函式「train_step」、「train_epoch」,在其中寫入訓練過程的邏輯。

  1.       def train_epoch(self):

  2.        """

  3.       implement the logic of epoch:

  4.       -loop ever the number of iteration in the config and call teh train step

  5.       -add any summaries you want using the sammary

  6.        """

  7.        pass

  8.    def train_step(self):

  9.        """

  10.       implement the logic of the train step

  11.       - run the tensorflow session

  12.       - return any metrics you need to summarize

  13.       """

  14.        pass

在主檔案中建立會話,建立以下物件:「Model」、「Logger」、「Data_Generator」、「Trainer」與配置:

  1.      sess = tf.Session()

  2.    # create instance of the model you want

  3.    model = VGGModel(config)

  4.    # create your data generator

  5.    data = DataGenerator(config)

  6.    # create tensorboard logger

  7.    logger = Logger(sess, config)

向所有這些物件傳遞訓練器物件,透過呼叫「trainer.train()」開始訓練。

  1.       trainer = VGGTrainer(sess, model, data, config, logger)

  2.    # here you train your model

  3.    trainer.train()

你會看到模板檔案、一個示例模型和訓練資料夾,向你展示如何快速開始你的第一個模型。

詳述


模型架構


快速開啟你的第一個專案:TensorFlow專案架構模板快速開啟你的第一個專案:TensorFlow專案架構模板


資料夾結構


  1.       ├──  base

  2.   ├── base_model.py   - this file contains the abstract class of the model.

  3.   └── ease_train.py - this file contains the abstract class of the trainer.

  4. ├── model               -This folder contains any model of your project.

  5.   └── example_model.py

  6. ├── trainer             -this folder contains trainers of your project.

  7.   └── example_trainer.py

  8.  

  9. ├──  mains              - here's the main/s of your project (you may need more than one main.

  10. │                        

  11. │  

  12. ├──  data _loader  

  13. │    └── data_generator.py  - here's the data_generator that responsible for all data handling.

  14. └── utils

  15.     ├── logger.py

  16.     └── any_other_utils_you_need

主要元件


模型

  • 基礎模型

基礎模型是一個必須由你所建立的模型繼承的抽象類,其背後的思路是:絕大多數模型之間都有很多東西是可以共享的。基礎模型包含:

  • Save-此函式可儲存 checkpoint 至桌面。

  • Load-此函式可載入桌面上的 checkpoint。

  • Cur-epoch、Global_step counters-這些變數會跟蹤訓練 epoch 和全域性步。

  • Init_Saver-一個抽象函式,用於初始化儲存和載入 checkpoint 的操作,注意:請在要實現的模型中覆蓋此函式。

  • Build_model-是一個定義模型的抽象函式,注意:請在要實現的模型中覆蓋此函式。

  • 你的模型

以下是你在模型中執行的地方。因此,你應該:

  • 建立你的模型類並繼承 base_model 類。

  • 覆寫 "build_model",在其中寫入你想要的 tensorflow 模型。

  • 覆寫"init_save",在其中你建立 tensorflow 儲存器,以用它儲存和載入檢查點。

  • 在 initalizer 中呼叫"build_model" 和 "init_saver"

訓練器

  • 基礎訓練器

基礎訓練器(Base trainer)是一個只包裝訓練過程的抽象的類。

  • 你的訓練器

以下是你應該在訓練器中執行的。

  • 建立你的訓練器類,並繼承 base_trainer 類。

  • 覆寫這兩個函式,在其中你執行每一步和每一 epoch 的訓練過程。

資料載入器

這些類負責所有的資料操作和處理,並提供一個可被訓練器使用的易用介面。

記錄器(Logger)

這個類負責 tensorboard 總結。在你的訓練器中建立一個有關所有你想要的 tensorflow 變數的詞典,並將其傳遞給 logger.summarize()。

配置

我使用 Json 作為配置方法,接著解析它,因此寫入所有你想要的配置,然後用"utils/config/process_config"解析它,並把這個配置物件傳遞給所有其他物件。

Main

以下是你整合的所有之前的部分。

1. 解析配置檔案。

2. 建立一個 TensorFlow 會話。

3. 建立 "Model"、"Data_Generator" 和 "Logger"例項,並解析所有它們的配置。

4. 建立一個"Trainer"例項,並把之前所有的物件傳遞給它。

5. 現在你可透過呼叫"Trainer.train()"訓練你的模型。

未來工作


未來,該專案計劃透過新的 TensorFlow 資料集 API 替代資料載入器。

相關文章