深度學習推理時融合BN,輕鬆獲得約5%的提速

ShellCollector發表於2020-12-01
批歸一化(Batch Normalization)因其可以加速神經網路訓練、使網路訓練更穩定,而且還有一定的正則化效果,所以得到了非常廣泛的應用。但是,在推理階段,BN層一般是可以完全融合到前面的卷積層的,而且絲毫不影響效能。
 

Batch Normalization是谷歌研究員於2015年提出的一種歸一化方法,其思想非常簡單,一句話概括就是,對一個神經元(或者一個卷積核)的輸出減去統計得到的均值除以標準差然後乘以一個可學習的係數,再加上一個偏置,這個過程就完成了。

下面我們簡單介紹一下BN訓練時怎麼做,推理的時候為什麼可以融合,以及怎麼樣融合。

一. BN訓練時如何做

1.png2.png

二. BN推理時怎麼做

3.png4.png5.png6.png
 

三. 在框架中如何融合

c0f202aa33a735fe3ab45b06c5fd894.png

下面是來自博文[1]中的一個PyTorch例子,將ResNet18中一個卷積+BN層融合後,融合前後輸出的差值為-6.10425390790148e-11,也就是誤差在百億分之一,基本就是0了。

    import torch
    import torchvision

    def fuse(conv, bn):

        fused = torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            bias=True
        )

        # setting weights
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
        fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )

        # setting bias
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros( conv.weight.size(0) )
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
                              torch.sqrt(bn.running_var + bn.eps)
                            )
        fused.bias.copy_( b_conv + b_bn )

        return fused

    # Testing
    # we need to turn off gradient calculation because we didn't write it
    torch.set_grad_enabled(False)
    x = torch.randn(16, 3, 256, 256)
    resnet18 = torchvision.models.resnet18(pretrained=True)
    # removing all learning variables, etc
    resnet18.eval()
    model = torch.nn.Sequential(
        resnet18.conv1,
        resnet18.bn1
    )
    f1 = model.forward(x)
    fused = fuse(model[0], model[1])
    f2 = fused.forward(x)
    d = (f1 - f2).mean().item()
    print("error:",d)

因為這麼一個線性操作的轉換,如果有誤差,那才真是見鬼了呢。

關於其他框架,如Keras、Caffe、TensorFlow的操作,與PyTorch基本一個原理,大家可以自己試驗一下。

筆者在測試時候,發現融合掉BN後,會有大概5%的提速,而且還可以減小視訊記憶體消耗,又絲毫不影響誤差,何樂而不為呢。

但是,融合BN僅限於Conv+BN或者是BN+Conv結構,中間不能加非線性層,例如Conv+ReLu+BN那就不行了。當然,一般結構都是Conv+BN+ReLu結構。

相關文章