caffe層解讀系列——BatchNorm

shuzfan發表於2016-10-03

之前也寫過一篇介紹 Batch Normalization 的文章,原理還不是很清楚的童鞋可以移步看一下。後來看到caffe中的實現,發現還是有很大不同之處,所以這裡介紹一些caffe中的BN。

—————————— 可選引數 ——————————

可選引數定義在 src\caffe\proto\caffe.proto 中,共有3個:

message BatchNormParameter {
  // 如果為真,則使用儲存的均值和方差,否則採用滑動平均計算新的均值和方差。
  // 該引數預設的時候,如果是測試階段則等價為真,如果是訓練階段則等價為假。
  optional bool use_global_stats = 1;

  // 滑動平均的衰減係數,預設為0.999
  optional float moving_average_fraction = 2 [default = .999];

  // 分母附加值,防止除以方差時出現除0操作,預設為1e-5
  optional float eps = 3 [default = 1e-5];
}

—————————— 前向傳播 ——————————

說前向傳播之前,先注意幾點:

(1)均值和方差的個數:

  // 如果bottom是1維的,則均值和方差個數為1,否則等於通道數
  if (bottom[0]->num_axes() == 1)
    channels_ = 1;
  else
    channels_ = bottom[0]->shape(1);

(2)均值和方差的更新:

均值和方差採用的是滑動平均的更新方式。因此,BN層共儲存了3個數值:均值滑動和、方差滑動和、滑動係數和。
計算公式如下:

設滑動係數 moving_average_fraction 為 \(\lambda\) ,m = bottom[0]->count() / channels_,儲存的三個數值(均值滑動和、方差滑動和、滑動係數和)分別為 \(\mu_{old}, \sigma_{old}, s_{old}\), 當前batch計算的均值和方差為\(\mu, \sigma\)。則:

\(s_{new}=\lambda s_{old}+1\);

\(\mu_{new} = \lambda \mu_{old}+\mu \);

對於方差,採用的是無偏估計,
\(\sigma_{new} = \lambda\sigma_{old}+m \sigma \qquad if(m>1),m = m/(m-1)\);

(3)均值和方差的使用:

caffe到目前仍然沒有實現和論文原文保持一致的BN層,即沒有 \(\alpha和\beta\) 引數,因此更新公式就比較簡單了,為每一個channel施加如下公式:

\(x = \frac{x-\mu}{\sigma}\);

但是需要注意的是,我們儲存的是均值和方差的滑動和,因此還要做一些處理。

還是設儲存的三個數值(均值滑動和、方差滑動和、滑動係數和)分別為 \(\mu_{old}, \sigma_{old}, s_{old}\)。

首先要計算一個縮放係數\(s = 1/s_{old} \qquad if(s_{old}==0),s = 1\),則:

\(\mu = s*\mu_{old}\);

方差的計算還是要稍微複雜一點:

\(\sigma = s*\sigma_{old}\);

\(\sigma = {(\sigma + eps)}^{0.5}\)

相關文章