Using BN (Batch Normalization) in TensorFlow
The answers you are looking for may all be here: Xiaopeng's blog index
MachineLP's GitHub (feel free to follow): https://github.com/MachineLP
My GitHub: https://github.com/MachineLP/train_cnn-rnn-attention, a framework I built myself, including these models: vgg (vgg16, vgg19), resnet (resnet_v2_50, resnet_v2_101, resnet_v2_152), inception_v4, inception_resnet_v2, and more.
Note: do not add BN indiscriminately; on some problems it makes the loss larger.
The previous post covered the theory behind Batch Normalization; here is a look at the TensorFlow implementation. BN can be added after convolutional layers as well as after fully connected layers:
(1) At training time, is_training is True.
import tensorflow as tf
import numpy as np
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages

# These names were undefined in the original snippet; the values below are
# representative (matching the linked resnet implementation), not mandatory.
BN_DECAY = 0.9997                     # decay rate for the moving averages
BN_EPSILON = 0.001                    # small constant to avoid division by zero
UPDATE_OPS_COLLECTION = 'update_ops'  # collection holding the moving-average updates

def _get_variable(name, shape, initializer, trainable=True):
    # Thin wrapper so every BN variable goes through one creation path.
    return tf.get_variable(name, shape=shape, initializer=initializer,
                           trainable=trainable)

def bn(x, is_training):
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]           # one gamma/beta per channel
    axis = list(range(len(x_shape) - 1))  # normalize over all but the channel axis
    beta = _get_variable('beta', params_shape, initializer=tf.zeros_initializer())
    gamma = _get_variable('gamma', params_shape, initializer=tf.ones_initializer())
    moving_mean = _get_variable('moving_mean', params_shape,
                                initializer=tf.zeros_initializer(), trainable=False)
    moving_variance = _get_variable('moving_variance', params_shape,
                                    initializer=tf.ones_initializer(), trainable=False)
    # These ops will only be performed when training.
    mean, variance = tf.nn.moments(x, axis)
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, BN_DECAY)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, BN_DECAY)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
    # Use batch statistics while training, the moving averages at inference.
    mean, variance = control_flow_ops.cond(
        is_training, lambda: (mean, variance),
        lambda: (moving_mean, moving_variance))
    return tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
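A minimal usage sketch building on the bn() above (the input shape, the dummy loss, and the scope name bn1 are assumptions for illustration): the moving-average update ops collected above must be run alongside each training step, and is_training is fed as False at inference so the moving averages are used.

x = tf.placeholder(tf.float32, [None, 28, 28, 64])  # e.g. conv feature maps
is_training = tf.placeholder(tf.bool, name='is_training')
with tf.variable_scope('bn1'):
    y = bn(x, is_training)
# Group the moving-average updates so they run with every training step.
update_ops = tf.group(*tf.get_collection(UPDATE_OPS_COLLECTION))
loss = tf.reduce_mean(tf.square(y))  # dummy loss, just for the sketch
train_op = tf.group(tf.train.GradientDescentOptimizer(0.01).minimize(loss),
                    update_ops)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.rand(32, 28, 28, 64).astype('float32')
    sess.run(train_op, {x: batch, is_training: True})  # training: batch statistics
    sess.run(y, {x: batch, is_training: False})        # inference: moving averages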
Function:
tf.nn.batch_normalization()
def batch_normalization(x,
                        mean,
                        variance,
                        offset,
                        scale,
                        variance_epsilon,
                        name=None):
Args:
- x: Input Tensor of arbitrary dimensionality.
- mean: A mean Tensor.
- variance: A variance Tensor.
- offset: An offset Tensor, often denoted β in equations, or None. If present, will be added to the normalized tensor.
- scale: A scale Tensor, often denoted γ in equations, or None. If present, the scale is applied to the normalized tensor.
- variance_epsilon: A small float number to avoid dividing by 0.
- name: A name for this operation (optional).
- Returns: the normalized, scaled, offset tensor.
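As a quick illustration, here is a minimal sketch of calling it directly on a fully connected activation (the 256-unit shape and the fc_beta/fc_gamma variable names are assumptions for the example):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 256])  # fully connected activations
mean, variance = tf.nn.moments(x, axes=[0])  # per-feature batch statistics
beta = tf.get_variable('fc_beta', [256], initializer=tf.zeros_initializer())
gamma = tf.get_variable('fc_gamma', [256], initializer=tf.ones_initializer())
y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-3)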
For convolutions, x has shape [batch, height, width, depth].
For convolutions we want γi and βi to be shared across each feature map, so γ and β have dimension [depth].
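A quick shape check of this point (the 32x32x64 input is an arbitrary example): taking moments over the batch, height, and width axes leaves statistics, and hence γ and β, of shape [depth].

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 32, 32, 64])  # [batch, height, width, depth]
mean, variance = tf.nn.moments(x, axes=[0, 1, 2])   # shared over each feature map
print(mean.get_shape())                             # (64,), i.e. [depth]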
In addition, here is an example that uses batch normalization: martin-gorner/tensorflow-mnist-tutorial
See also the resnet implementation: https://github.com/MachineLP/tensorflow-resnet
Another excellent write-up: how to introduce BatchNorm in CNNs and RNNs
On loading a trained model: usage of batch normalization in TensorFlow