【深度學習筆記】Batch Normalization (BN)
Batch Normalization: Accelerating Deep Network Training b y Reducing Internal Covariate Shift這篇文章是谷歌2015年提出的一個深層網路訓練技巧,Batch Normalization(簡稱BN)不僅可以加快了模型的收斂速度,而且更重要的是在一定程度緩解了深層網路中“梯度彌散”的問題(梯度彌散:,在BN中,通過將activation規範為均值和方差一致的手段使得原本會減小的activation的scale變大),從而使得訓練深層網路模型更加容易和穩定。
BN主要分為三步:
- 求每一個batch的資料均值和方差
- 使用求得的均值和方差對該批次的訓練資料做歸一化,獲得0-1分佈。其中是為了避免分母為0。
- 尺度變換和偏移:將乘以調整數值大小,再加上增加偏移後得到,這裡的控制縮放,控制偏移。由於歸一化後的基本會被限制在正態分佈下,使得網路的表達能力下降,影響到network的capacity。為解決該問題,引入兩個新的引數,這兩個引數是在訓練時由網路學習得到的,如此一來,既可以改變同時也可以保持原輸入,那麼模型的容納能力(capacity)就提升了。
在訓練時,會對同一批的資料的均值和方差進行求解,進而進行歸一化操作。對於預測階段時所使用的均值和方差,可以是來源於訓練集,訓練時每次計算每個batch的方差與均值,為了使得每個batch的方差與均值儘可能的接近整體分佈方差與均值的估計值,這裡通過滑動平均求整個訓練樣本的均值和方差期望值,作為我們進行預測時進行BN的的均值和方差。滑動係數為,當前batch計算的均值和方差為,那麼
均值更新:
方差更新,採用無偏估計:
在caffe的BN層中use_global_stats:如果為真,則使用儲存的均值和方差,否則採用滑動平均計算新的均值和方差。該引數預設的時候,如果是測試階段則等價為真,如果是訓練階段則等價為假。
在tensorflow中,使用bn,注意以下幾項:
1、訓練時,模型輸入引數training=True
def forward(self, inputs, is_training=False, reuse=False):
# set batch norm params
batch_norm_params = {
'decay': self.batch_norm_decay,
'epsilon': 1e-05,
'scale': True,
'is_training': is_training,
'fused': None, # Use fused batch norm if possible.
}
2、訓練時,如果是使用var_list = tf.trainable_variables()是不包含通過滑動平均計算出的均值和方差這兩個引數,所以如下程式碼的方式,令var_list=update_vars
parser.add_argument("--update_part", nargs='*', type=str, default=['tiny_yolo/yolov3_head'],
help="Partially restore part of the model for finetuning. Set [None] to train the whole model.")
# define yolo-v3 model here
yolo_model = tiny_yolo(args.class_num, args.anchors)
with tf.variable_scope('tiny_yolo'):
pred_feature_maps = yolo_model.forward(image, is_training=is_training)
loss = yolo_model.compute_loss(pred_feature_maps, y_true)
y_pred = yolo_model.predict(pred_feature_maps)
update_vars = tf.contrib.framework.get_variables_to_restore(include=args.update_part)
# set dependencies for BN ops
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss[0], var_list=update_vars, global_step=global_step)
3、測試時,模型輸入引數training=False即可
參考資料:
[1] https://arxiv.org/pdf/1502.03167.pdf
[2] https://www.cnblogs.com/skyfsm/p/8453498.html
[3]深度學習中 Batch Normalization為什麼效果好?https://www.zhihu.com/question/38102762
[4]caffe層解讀系列——BatchNorm https://blog.csdn.net/shuzfan/article/details/52729424
相關文章
- batch normalization學習理解筆記BATORM筆記
- 深度學習中 Batch Normalization深度學習BATORM
- BN(Batch Normalization)層的詳細介紹BATORM
- [PyTorch 學習筆記] 6.2 NormalizationPyTorch筆記ORM
- 解毒batch normalizationBATORM
- 深度學習中的Normalization模型深度學習ORM模型
- TensorFlow實現Batch NormalizationBATORM
- Batch Normalization: 如何更快地訓練深度神經網路BATORM神經網路
- 深度學習 筆記一深度學習筆記
- 深度學習keras筆記深度學習Keras筆記
- 深度學習框架Pytorch學習筆記深度學習框架PyTorch筆記
- 深度學習中的Normalization模型(附例項&公式)深度學習ORM模型公式
- 深度學習 DEEP LEARNING 學習筆記(一)深度學習筆記
- 深度學習 DEEP LEARNING 學習筆記(二)深度學習筆記
- 深度學習——loss函式的學習筆記深度學習函式筆記
- 【深度學習】大牛的《深度學習》筆記,Deep Learning速成教程深度學習筆記
- 學習筆記:深度學習中的正則化筆記深度學習
- 李巨集毅深度學習 筆記(四)深度學習筆記
- 深度學習當中的三個概念:Epoch, Batch, Iteration深度學習BAT
- 深度學習推理時融合BN,輕鬆獲得約5%的提速深度學習
- 深度學習入門筆記——Transform的使用深度學習筆記ORM
- 深度學習入門筆記——DataLoader的使用深度學習筆記
- 深度學習筆記(5)Broadcasting in Python 廣播深度學習筆記ASTPython
- 【深度學習】深度學習md筆記總結第1篇:深度學習課程,要求【附程式碼文件】深度學習筆記
- Python深度學習(處理文字資料)--學習筆記(十二)Python深度學習筆記
- 【筆記】動手學深度學習-預備知識筆記深度學習
- 學習筆記【深度學習2】:AI、機器學習、表示學習、深度學習,第一次大衰退筆記深度學習AI機器學習
- 基於深度學習的醫學影像配準學習筆記2深度學習筆記
- 深度學習筆記------卷積神經網路深度學習筆記卷積神經網路
- 深度學習筆記002-線性迴歸深度學習筆記
- 深度學習卷積神經網路筆記深度學習卷積神經網路筆記
- 吳恩達《神經網路與深度學習》課程筆記(1)– 深度學習概述吳恩達神經網路深度學習筆記
- 22張精煉圖筆記,深度學習專項學習必備筆記深度學習
- numpy的學習筆記\pandas學習筆記筆記
- 《深度學習入門》第 2 章 感知機 筆記深度學習筆記
- 李巨集毅深度學習 筆記(七)Auto-encoder深度學習筆記
- 遷移學習中的BN問題遷移學習
- 淺談深度學習訓練中資料規範化(Normalization)的重要性深度學習ORM