【深度學習筆記】Batch Normalization (BN)

通訊程式猿發表於2019-01-07

Batch Normalization: Accelerating Deep Network Training b y Reducing Internal Covariate Shift這篇文章是谷歌2015年提出的一個深層網路訓練技巧,Batch Normalization(簡稱BN)不僅可以加快了模型的收斂速度,而且更重要的是在一定程度緩解了深層網路中“梯度彌散”的問題(梯度彌散:0.9^{30}\approx 0.04,在BN中,通過將activation規範為均值和方差一致的手段使得原本會減小的activation的scale變大),從而使得訓練深層網路模型更加容易和穩定。

BN主要分為三步:

  1. 求每一個batch的資料均值和方差
  2. 使用求得的均值和方差對該批次的訓練資料做歸一化,獲得0-1分佈。其中\epsilon是為了避免分母為0。
  3. 尺度變換和偏移:將\hat{x_i}乘以\gamma調整數值大小,再加上\beta增加偏移後得到y_i,這裡的\gamma控制縮放,\beta控制偏移。由於歸一化後的\hat{x_i}基本會被限制在正態分佈下,使得網路的表達能力下降,影響到network的capacity。為解決該問題,引入兩個新的引數\gamma ,\beta,這兩個引數是在訓練時由網路學習得到的,如此一來,既可以改變同時也可以保持原輸入,那麼模型的容納能力(capacity)就提升了。

在訓練時,會對同一批的資料的均值和方差進行求解,進而進行歸一化操作。對於預測階段時所使用的均值和方差,可以是來源於訓練集,訓練時每次計算每個batch的方差與均值,為了使得每個batch的方差與均值儘可能的接近整體分佈方差與均值的估計值,這裡通過滑動平均求整個訓練樣本的均值和方差期望值,作為我們進行預測時進行BN的的均值和方差。滑動係數為\lambda,當前batch計算的均值和方差為\mu,\sigma,那麼

均值更新:\mu_{new} = \lambda\mu_{old}+\mu

方差更新,採用無偏估計:\sigma_{new} = \lambda\sigma_{old}+\sigma

在caffe的BN層中use_global_stats:如果為真,則使用儲存的均值和方差,否則採用滑動平均計算新的均值和方差。該引數預設的時候,如果是測試階段則等價為真,如果是訓練階段則等價為假。

 

在tensorflow中,使用bn,注意以下幾項:

1、訓練時,模型輸入引數training=True

    def forward(self, inputs, is_training=False, reuse=False):
        # set batch norm params
        batch_norm_params = {
            'decay': self.batch_norm_decay,
            'epsilon': 1e-05,
            'scale': True,
            'is_training': is_training,
            'fused': None,  # Use fused batch norm if possible.
        }

2、訓練時,如果是使用var_list = tf.trainable_variables()是不包含通過滑動平均計算出的均值\mu和方差\sigma這兩個引數,所以如下程式碼的方式,令var_list=update_vars

parser.add_argument("--update_part", nargs='*', type=str, default=['tiny_yolo/yolov3_head'],
                    help="Partially restore part of the model for finetuning. Set [None] to train the whole model.")

# define yolo-v3 model here
yolo_model = tiny_yolo(args.class_num, args.anchors)
with tf.variable_scope('tiny_yolo'):
    pred_feature_maps = yolo_model.forward(image, is_training=is_training)
loss = yolo_model.compute_loss(pred_feature_maps, y_true)
y_pred = yolo_model.predict(pred_feature_maps)


update_vars = tf.contrib.framework.get_variables_to_restore(include=args.update_part)

# set dependencies for BN ops
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss[0], var_list=update_vars, global_step=global_step)

3、測試時,模型輸入引數training=False即可

 

 

參考資料:

[1] https://arxiv.org/pdf/1502.03167.pdf

[2] https://www.cnblogs.com/skyfsm/p/8453498.html

[3]深度學習中 Batch Normalization為什麼效果好?https://www.zhihu.com/question/38102762

[4]caffe層解讀系列——BatchNorm https://blog.csdn.net/shuzfan/article/details/52729424

[5]https://www.cnblogs.com/hrlnw/p/7227447.html

相關文章