SSD演算法程式碼介訓練演算法整體架構

大佬111發表於2018-09-14




主要介紹了訓練模型的一些引數配置資訊,可以看出在訓練指令碼train.py中主要是呼叫train_net.py指令碼中的train_net函式進行訓練的,因此這一篇部落格介紹train_net.py指令碼的內容。

train_net.py這個指令碼一共包含convert_pretrained,get_lr_scheduler,train_net三個函式,其中最重要的是train_net函式,這個函式也是train.py指令碼訓練模型時候呼叫的函式,建議從train_net函式開始看起。

import tools.find_mxnetimport mxnet as mximport loggingimport sysimport osimport importlibimport re# 匯入生成模型可用的資料格式的類,是在dataset資料夾下的iterator.py指令碼中實現的,# 一般採用這種匯入指令碼中類的方式需要在dataset資料夾下寫一個空的__init__.py指令碼才能匯入from dataset.iterator import DetRecordIter 
from train.metric import MultiBoxMetric # 匯入訓練時候的評價標準類# 匯入測試時候的評價標準類,這裡VOC07MApMetric類繼承了MApMetric類,主要內容在MApMetric類中from evaluate.eval_metric import MApMetric, VOC07MApMetric 
from config.config import cfgfrom symbol.symbol_factory import get_symbol_train # get_symbol_train函式來匯入symboldef convert_pretrained(name, args):
    """
    Special operations need to be made due to name inconsistance, etc
    Parameters:
    ---------
    name : str
        pretrained model name
    args : dict
        loaded arguments
    Returns:
    ---------
    processed arguments as dict
    """
    return args# get_lr_scheduler函式就是設計你的學習率變化策略,函式的幾個輸入的意思在這裡都介紹得很清楚了,# lr_refactor_step可以是3或6這樣的單獨數字,也可以是3,6,9這樣用逗號間隔的數字,表示到第3,6,9個epoch的時候就要改變學習率def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute learning rate and refactor scheduler
    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs to change learning rate
    lr_refactor_ratio : float
        lr *= ratio at certain steps
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch
    Returns:
    ---------
    (learning_rate, mx.lr_scheduler) as tuple
    """
    assert lr_refactor_ratio > 0
    iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]    # 學習率的改變一般都是越來越小,不接受學習率越來越大這種策略,在這種情況下采用學習率不變的策略
    if lr_refactor_ratio >= 1: 
        return (learning_rate, None)    else:
        lr = learning_rate
        epoch_size = num_example // batch_size # 表示每個epoch最少包含多少個batch# 這個for迴圈的內容主要是解決當你設定的begin_epoch要大於你的iter_refactor的某些值的時候,# 會按照lr_refactor_ratio改變你的初始學習率,也就是說這個改變是還沒開始訓練的時候就做的。
        for s in iter_refactor: 
            if begin_epoch >= s:
                lr *= lr_refactor_ratio# 如果有上面這個學習率的改變,那麼列印出改變資訊,這樣以後看log也能很清楚地知道當時實際初始學習率是多少。
        if lr != learning_rate: 
            logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))# 這個steps就是你要執行多少個batch才需要改變學習率,因此這個steps是以batch為單位的
        steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]# 這個if條件滿足的話就表示我的begin_epoch比你設定的iter_refactor裡面的所有值都大,那麼我就返回學習率lr,# 至於更改的策略就只能是None了,也就是說用這個lr一直跑到結束,中間就不改變了
        if not steps: 
            return (lr, None)# 最終用mx.lr_scheduler.MultiFactorScheduler函式生成模型可用的lr_scheduler
        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)        return (lr, lr_scheduler)# 這是train_net.py指令碼中的主要函式def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.
    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        trainig momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlaped objects from different classes
    train_list : str
        list file path for training, this will replace the embeded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embeded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger# 這部分內容和生成日誌檔案相關,依賴logging這個庫,if條件中的log_file就是生成的log檔案的路徑和名稱。# 這個logger是RootLogger型別,可以用來輸出提示資訊,# 用法例子:logger.info("Start finetuning with {} from epoch {}".format(ctx, epoch))
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)    # check args# 這一部分主要是檢查一些配置引數是不是異常,比如你的data_shape必須是個int型等
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]    assert len(mean_pixels) == 3, "must provide all RGB mean values"# 這裡的train_iter是透過呼叫dataset資料夾下的iterator.py指令碼中的DetRecordIter類來得到的,# 簡單講就是從.rec和.lst檔案到模型可以用的資料迭代器的過程。輸入中train_path是你的.rec檔案的路徑,# label_pad_width這個引數在文中的解釋是force padding training and validation labels to sync their labels widths,# train_list是空字串。
    train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
        label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)# 如果你給了驗證集資料的路徑,那麼也生成驗證集資料迭代器,做法和前面訓練集的做法一樣
    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
            label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)    else:
        val_iter = None
    # load symbol# 這裡呼叫了symbol資料夾下的symbol_factory.py指令碼的get_symbol_train函式來匯入symbol。這個函式的輸入中,net是一個str,# 程式碼中預設是‘vgg16_reduced’,data_shape是一個tuple,是在前面計算得到的,比如data_shape是(3,300,300),num_classes就是類別數,# 在VOC資料集中,num_classes就是20,nms_thresh是nms操作的引數,預設是0.45,# force_suppress和nms_topk兩個引數都是採用預設的False和400。# 這個函式的輸出net就是最終的檢測網路,是一個symbol。
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
        nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)    # define layers with fixed weight/bias# 這一步是設計一些層的引數在模型訓練過程中不變,freeze_layer_pattern是在train.py裡面設定的一個引數,表示要將哪些層的引數固定。# 最後得到的fixed_param_names就是一個list,其中的每個元素就是層引數的名稱,比如conv1_1_weight,是一個str。
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]    else:
        fixed_param_names = None
    # load pretrained or resume from previous state# resume是指你在訓練檢測模型的時候如果訓練到一半但是中斷了,想要從中斷的epoch繼續訓練,# 那麼可以匯入訓練中斷前的那個epoch的.param檔案,# 這個檔案就是檢測模型的引數,從而用這個引數初始化檢測模型,達到斷點繼續訓練的目的。
    ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune        # check what layers mismatch with the loaded parameters
        exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300), label=(1, 1, 5), grad_req='null')
        arg_dict = exe.arg_dict
    fixed_param_names = []        for k, v in arg_dict.items():            if k in args:                if v.shape != args[k].shape:                    del args[k]
                    logging.info("Removed %s" % k)                else:            if not 'pred' in k:
                fixed_param_names.append(k)# 這個if條件是匯入預訓練好的分類模型來初始化檢測模型的引數,其中mxnet.model.checkpoint就是執行這個匯入引數的作用,# 生成的_是分類模型的網路,args是分類模型的引數,型別是dictionary,每個item表示一個層引數,item的內容就是一個引數的NDArray格式。# auxs在這裡是一個空字典。最後呼叫的這個convert_pretrained函式就是該指令碼定義的第一個函式,直接return args,沒做什麼操作。
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
            .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None
    # helper information
    # 這一部分將前面得到的要固定引數的層資訊列印出來
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')    # init training module# 呼叫mx.mod.Module類初始化一個模型。引數中net就是前面透過get_symbol_train函式匯入的檢測模型的symbol。# logger是和日誌相關的引數。ctx就是你訓練模型時候的cpu或gpu選擇。初始化model的時候就要指定要固定的引數。
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)    # fit parameters
 # 這個frequent就是你每隔frequent個batch顯示一次訓練結果(比如損失,準確率等等),程式碼中frequent採用20。
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent) # prefix是一個指定的路徑,生成的epoch_end_callback作為最後fit()函式的引數之一,用來指定生成的模型的存放地址。
    epoch_end_callback = mx.callback.do_checkpoint(prefix)# 呼叫get_lr_scheduler()函式生成初始的學習率和學習率變化策略,這個get_lr_scheduler()函式在前面有詳細介紹
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
        lr_refactor_ratio, num_example, batch_size, begin_epoch)# 定義最佳化器的一些引數,比如學習率;momentum(該引數是在sgd演算法中計算下一步更新方向時候會用到,預設是0.9);# wd是正則項的係數,一般採用0.0001到0.0005,程式碼中預設是0.0005;lr_scheduler是學習率的更新策略,# 比如你間隔20個epoch就把學習率降為原來的0.1倍等;# rescale_grad引數如果你是一塊GPU跑,就是預設的1,如果是多GPU,那麼相當於在做梯度更新的時候需要合併多個GPU的結果,# 這裡ctx就是代表你是用cpu還是gpu,以及gpu的話是採用哪幾塊gpu。
    optimizer_params={'learning_rate':learning_rate,                      'momentum':momentum,                      'wd':weight_decay,                      'lr_scheduler':lr_scheduler,                      'clip_gradient':None,                      'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0 }# 這個monitor一般是除錯時候採用,預設訓練模型的時候這個monitor是None,也就是iter_monitor預設是0
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None
    # run fit net, every n epochs we run evaluation network to get mAP# 這一步是對評價指標的選擇,指令碼中中預設採用voc07_metric,ovp_thresh預設是0.5,# 表示計算MAp時類別相同的預測框和真實框的IOU值的閾值。
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)# 模型訓練的入口,這個mod只有檢測網路的結構資訊,而fit的arg_params引數則是指定了用來初始化這個檢測模型的引數,# 這些引數來自預訓練好的分類模型。# 如果你在除錯模型的時候執行到fit這個函式,進入這個函式的話就進入到mxnet專案的base_module.py指令碼,# 裡面包含了引數初始化和模型前後向的具體操作。
    mod.fit(train_iter, # 訓練資料
            val_iter, # 測試資料
            eval_metric=MultiBoxMetric(), # 訓練時的評價指標
            validation_metric=valid_metric, # 測試時的評價指標# 每多少個batch顯示結果,這個batch_end_callback引數是由mx.callback.Speedometer()函式生成的,# 這個函式的輸入包括batch_size和間隔
            batch_end_callback=batch_end_callback, 
# 每個epoch結束後,得到的.param檔案存放地址,這個epoch_end_callback由mx.callback,do_checkpoint()函式生成,# 這個函式的輸入就是存放地址。
            epoch_end_callback=epoch_end_callback, 
            optimizer='sgd', # 最佳化演算法採用sgd,也就是隨機梯度下降
            optimizer_params=optimizer_params, # 最佳化器的一些引數
            begin_epoch=begin_epoch, # epoch的初始值
            num_epoch=end_epoch, # 一共要訓練多少個epoch
            initializer=mx.init.Xavier(), # 其他引數的初始化方式
            arg_params=args, # 匯入的模型的引數,就是你預訓練的模型的引數
            aux_params=auxs, # 匯入的模型的引數的均值方差
            allow_missing=True, # 是否允許一些引數缺失
            monitor=monitor) # 如果monitor為None的話,就沒什麼用了,因為fit()函式預設monitor引數為None

本文來源:https://blog.csdn.net/u014380165/article/details/79332365

來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/31548646/viewspace-2214183/,如需轉載,請註明出處,否則將追究法律責任。

相關文章