深度學習推理時融合BN,輕鬆獲得約5%的提速
批歸一化(Batch Normalization)因其可以加速神經網路訓練、使網路訓練更穩定,而且還有一定的正則化效果,所以得到了非常廣泛的應用。但是,在推理階段,BN層一般是可以完全融合到前面的卷積層的,而且絲毫不影響效能。
Batch Normalization是谷歌研究員於2015年提出的一種歸一化方法,其思想非常簡單,一句話概括就是,對一個神經元(或者一個卷積核)的輸出減去統計得到的均值,除以標準差,然後乘以一個可學習的係數,再加上一個偏置,這個過程就完成了。
下面我們簡單介紹一下BN訓練時怎麼做,推理的時候為什麼可以融合,以及怎麼樣融合。
一. BN訓練時如何做
二. BN推理時怎麼做
三. 在框架中如何融合
下面是來自博文[1]中的一個PyTorch例子,將ResNet18中一個卷積+BN層融合後,融合前後輸出的差值為-6.10425390790148e-11,也就是誤差在百億分之一,基本就是0了。
import torch
import torchvision
def fuse(conv, bn):
fused = torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True
)
# setting weights
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
# setting bias
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros( conv.weight.size(0) )
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
torch.sqrt(bn.running_var + bn.eps)
)
fused.bias.copy_( b_conv + b_bn )
return fused
# Testing
# we need to turn off gradient calculation because we didn't write it
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
resnet18 = torchvision.models.resnet18(pretrained=True)
# removing all learning variables, etc
resnet18.eval()
model = torch.nn.Sequential(
resnet18.conv1,
resnet18.bn1
)
f1 = model.forward(x)
fused = fuse(model[0], model[1])
f2 = fused.forward(x)
d = (f1 - f2).mean().item()
print("error:",d)
因為這麼一個線性操作的轉換,如果有誤差,那才真是見鬼了呢。
關於其他框架,如Keras、Caffe、TensorFlow的操作,與PyTorch基本一個原理,大家可以自己試驗一下。
筆者在測試時候,發現融合掉BN後,會有大概5%的提速,而且還可以減小視訊記憶體消耗,又絲毫不影響誤差,何樂而不為呢。
但是,融合BN僅限於Conv+BN或者是BN+Conv結構,中間不能加非線性層,例如Conv+ReLu+BN那就不行了。當然,一般結構都是Conv+BN+ReLu結構。
相關文章
- 【深度學習筆記】Batch Normalization (BN)深度學習筆記BATORM
- 阿里開源!輕量級深度學習端側推理引擎 MNN阿里深度學習
- 使用代理IP輕鬆獲得韓國IP地址
- 輕鬆學習 JavaScript(5):簡化函式提升JavaScript函式
- 怎麼輕鬆學習JavaScriptJavaScript
- 輕鬆學習 JavaScript——第 5 部分:簡化函式提升JavaScript函式
- 研學社·系統組 | 實時深度學習的推理加速和持續訓練深度學習
- 5招輕鬆獲取Mac檔案路徑Mac
- 如何輕鬆學習 Kubernetes?
- [譯] 如何輕鬆地在樹莓派上使用深度學習檢測物件樹莓派深度學習物件
- XML輕鬆學習手冊(5)XML語法之二(轉)XML
- XML輕鬆學習手冊(5)XML語法之四(轉)XML
- 【C#學習筆記】獲得系統時間C#筆記
- 輕鬆學習 JavaScript(8):JavaScript 中的類JavaScript
- [原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架原始碼深度學習分散式框架
- 遷移學習中的BN問題遷移學習
- Yii2 - Active Record 輕鬆學習
- 基於CPU的深度學習推理部署優化實踐深度學習優化
- 數字人輕鬆學Xpresso入門-5
- 實時深度學習深度學習
- 深度學習網路模型的輕量化方法深度學習模型
- CSDN 學習勳章獲得攻略
- 如何輕鬆學習Python資料分析?Python
- 萬字長文,帶你輕鬆學習 SparkSpark
- 輕鬆學習 JavaScript (4):函式中的 arguments 物件JavaScript函式物件
- 輕鬆學習 JavaScript——第 8 部分:JavaScript 中的類JavaScript
- 想輕鬆復現深度強化學習論文?看這篇經驗之談強化學習
- substrate輕鬆學系列5:編寫pallet的Rust前置知識Rust
- 輕鬆學習 JavaScript (2):函式中的 Rest 引數JavaScript函式REST
- AI 學習之路——輕鬆初探 Python 篇(三)AIPython
- 輕鬆學習 JavaScript(1):瞭解 let 語句JavaScript
- 輕鬆學習 JavaScript(6):JavaScript 箭頭函式JavaScript函式
- AI 學習之路——輕鬆初探 Python 篇(一)AIPython
- 一張圖輕鬆解讀《財富》人工智慧萬字長文,關於深度學習的前世今生人工智慧深度學習
- php獲得時間PHP
- 如何判斷深度學習推理是不是真的跑在顯示卡上了深度學習
- 輕鬆學習 JavaScript——第 4 部分:函式中的 arguments 物件JavaScript函式物件
- 輕鬆學習 JavaScript (3):函式中的預設引數JavaScript函式