深度學習推理時融合BN,輕鬆獲得約5%的提速
批歸一化(Batch Normalization)因其可以加速神經網路訓練、使網路訓練更穩定,而且還有一定的正則化效果,所以得到了非常廣泛的應用。但是,在推理階段,BN層一般是可以完全融合到前面的卷積層的,而且絲毫不影響效能。
Batch Normalization是谷歌研究員於2015年提出的一種歸一化方法,其思想非常簡單,一句話概括就是,對一個神經元(或者一個卷積核)的輸出減去統計得到的均值,除以標準差,然後乘以一個可學習的係數,再加上一個偏置,這個過程就完成了。
下面我們簡單介紹一下BN訓練時怎麼做,推理的時候為什麼可以融合,以及怎麼樣融合。
一. BN訓練時如何做
二. BN推理時怎麼做
三. 在框架中如何融合
下面是來自博文[1]中的一個PyTorch例子,將ResNet18中一個卷積+BN層融合後,融合前後輸出的差值為-6.10425390790148e-11,也就是誤差在百億分之一,基本就是0了。
import torch
import torchvision
def fuse(conv, bn):
fused = torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True
)
# setting weights
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
# setting bias
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros( conv.weight.size(0) )
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
torch.sqrt(bn.running_var + bn.eps)
)
fused.bias.copy_( b_conv + b_bn )
return fused
# Testing
# we need to turn off gradient calculation because we didn't write it
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
resnet18 = torchvision.models.resnet18(pretrained=True)
# removing all learning variables, etc
resnet18.eval()
model = torch.nn.Sequential(
resnet18.conv1,
resnet18.bn1
)
f1 = model.forward(x)
fused = fuse(model[0], model[1])
f2 = fused.forward(x)
d = (f1 - f2).mean().item()
print("error:",d)
因為這麼一個線性操作的轉換,如果有誤差,那才真是見鬼了呢。
關於其他框架,如Keras、Caffe、TensorFlow的操作,與PyTorch基本一個原理,大家可以自己試驗一下。
筆者在測試時候,發現融合掉BN後,會有大概5%的提速,而且還可以減小視訊記憶體消耗,又絲毫不影響誤差,何樂而不為呢。
但是,融合BN僅限於Conv+BN或者是BN+Conv結構,中間不能加非線性層,例如Conv+ReLu+BN那就不行了。當然,一般結構都是Conv+BN+ReLu結構。
相關文章
- 【深度學習筆記】Batch Normalization (BN)深度學習筆記BATORM
- 阿里開源!輕量級深度學習端側推理引擎 MNN阿里深度學習
- 使用代理IP輕鬆獲得韓國IP地址
- 如何輕鬆學習 Kubernetes?
- 怎麼輕鬆學習JavaScriptJavaScript
- 5招輕鬆獲取Mac檔案路徑Mac
- 機器學習的未來——深度特徵融合機器學習特徵
- [譯] 如何輕鬆地在樹莓派上使用深度學習檢測物件樹莓派深度學習物件
- 遷移學習中的BN問題遷移學習
- [原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架原始碼深度學習分散式框架
- 自媒體新手這樣運營讓你輕鬆獲得大票流量
- 5分鐘輕鬆學正規表示式
- 基於CPU的深度學習推理部署優化實踐深度學習優化
- 如何輕鬆學習Python資料分析?Python
- Yii2 - Active Record 輕鬆學習
- 深度學習網路模型的輕量化方法深度學習模型
- CSDN 學習勳章獲得攻略
- 實時深度學習深度學習
- AI 學習之路——輕鬆初探 Python 篇(三)AIPython
- 如何輕鬆利用GPU加速機器學習?GPU機器學習
- 想輕鬆復現深度強化學習論文?看這篇經驗之談強化學習
- substrate輕鬆學系列5:編寫pallet的Rust前置知識Rust
- 有輕功:用3行程式碼讓Python資料處理指令碼獲得4倍提速行程Python指令碼
- 31 天,從淺到深輕鬆學習 KotlinKotlin
- 萬字長文,帶你輕鬆學習 SparkSpark
- 時下火熱的 wGAN 將變革深度學習?這得從源頭講起深度學習
- 輕鬆讓圖片變得清晰Topaz Sharpen AIAI
- 僅1個例子輕鬆學習正規表示式
- 深度學習--實戰 LeNet5深度學習
- 使Mybatis開發變得更加輕鬆的增強工具 — OurbatisMyBatis
- (資料科學學習手札149)用matplotlib輕鬆繪製漂亮的表格資料科學
- 深度學習中的Lipschitz約束:泛化與生成模型深度學習模型
- 鋪天蓋地的炒作下,我依然覺得深度強化學習是浪費時間強化學習
- 輕鬆掌握useAsyncData獲取非同步資料非同步
- (資料科學學習手札90)Python+Kepler.gl輕鬆製作時間輪播地圖資料科學Python地圖
- Vue學習路徑-輕鬆從基礎到實戰Vue
- 機器是如何學習推理的?
- substrate輕鬆學系列1:前言