【論文筆記】張航和李沐等提出:ResNeSt: Split-Attention Networks(ResNet改進版本)

西西嘛呦發表於2020-04-18

github地址:https://github.com/zhanghang1989/ResNeSt

論文地址:https://hangzhang.org/files/resnest.pdf 

 

核心就是:Split-attention blocks

先看一組圖:

ResNeSt在影像分類上中ImageNet資料集上超越了其前輩ResNet、ResNeXt、SENet以及EfficientNet。使用ResNeSt-50為基本骨架的Faster-RCNN比使用ResNet-50的mAP要高出3.08%。使用ResNeSt-50為基本骨架的DeeplabV3比使用ResNet-50的mIOU要高出3.02%。漲點效果非常明顯。

1、提出的動機

他們認為像ResNet等一些基礎卷積神經網路是針對於影像分類而設計的。由於有限的感受野大小以及缺乏跨通道之間的相互作用,這些網路可能不適合於其它的一些領域像目標檢測、影像分割等。這意味著要提高給定計算機視覺任務的效能,需要“網路手術”來修改ResNet,以使其對特定任務更加有效。 例如,某些方法新增了金字塔模組[8,69]或引入了遠端連線[56]或使用跨通道特徵圖注意力[15,65]。 雖然這些方法確實可以提高某些任務的學習效能,但由此而提出了一個問題:我們是否可以建立具有通用改進功能表示的通用骨幹網,從而同時提高跨多個任務的效能?跨通道資訊在下游應用中已被成功使用 [56,64,65],而最近的影像分類網路更多地關注組或深度卷積[27,28,54,60]。 儘管它們在分類任務中具有出色的計算能力和準確性,但是這些模型無法很好地轉移到其他任務,因為它們的孤立表示無法捕獲跨通道之間的關係[27、28]。因此,具有跨通道表示的網路是值得做的。

2、本文的貢獻點

第一個貢獻點:提出了split-attention blocks構造的ResNeSt,與現有的ResNet變體相比,不需要增加額外的計算量。而且ResNeSt可以作為其它任務的骨架。

第二個貢獻點:影像分類和遷移學習應用的大規模基準。 利用ResNeSt主幹的模型能夠在幾個任務上達到最先進的效能,即:影像分類,物件檢測,例項分割和語義分割。 與通過神經架構搜尋生成的最新CNN模型[55]相比,所提出的ResNeSt效能優於所有現有ResNet變體,並且具有相同的計算效率,甚至可以實現更好的速度精度折衷。單個Cascade-RCNN [3]使用ResNeSt-101主幹的模型在MS-COCO例項分割上實現了48.3%的box mAP和41.56%的mask mAP。 單個DeepLabV3 [7]模型同樣使用ResNeSt-101主幹,在ADE20K場景分析驗證集上的mIoU達到46.9%,比以前的最佳結果高出1%mIoU以上。

3、相關工作就不介紹了

4、Split-Attention網路

直接看ResNeSt block:

首先是借鑑了ResNeXt網路的思想,將輸入分為K個,每一個記為Cardinal1-k ,然後又將每個Cardinal拆分成R個,每一個記為Split1-r,所以總共有G=KR個組。

然後是對於每一個Cardinal中具體是什麼樣的:

這裡借鑑了squeeze-and-excitation network(SENet) 中的思想,也就是基於通道的注意力機制,對通道賦予不同的權重以建模通道的重要程度。

對於每一個Cardinal輸入是:

通道權重統計量可以通過全域性平均池化獲得:

用Vk表示攜帶了通道權重後的Cardinal輸出:

那麼最終每個Cardinal的輸出就是:

而其中的是經過了softmax之後計算所得的權重:

如果R=1的話就是對該Cardinal中的所有通道視為一個整體。

接著將每一個Cardinal的輸出拼接起來:

假設每個ResNeSt block的輸出是Y,那麼就有:

其中T表示的是跳躍連線對映。這樣的形式就和ResNet中的殘差塊輸出計算就一致了。

5、殘差網路存在的問題 

(1)殘差網路使用帶步長的卷積,比如3×3卷積來減少影像的空間維度,這樣會損失掉很多空間資訊。對於像目標檢測和分割領域,空間資訊是至關重要的。而且卷積層一般使用0來填充影像邊界,這在遷移到密集預測的其它問題時也不是最佳選擇。因此本文使用的是核大小為3×3的平均池化來減少空間維度

(2)

  • 將殘差網路中的7×7卷積用3個3×3的卷積代替,擁有同樣的感受野。
  • 將跳躍連線中的步長為2的1×1卷積用2×2的平均池化代替。

6、訓練策略

這裡就簡單地列下,相關細節可以去看論文。

(1)大的min batch,使用cosine學習率衰減策略。warm up。BN層引數設定。

(2)標籤平滑

(3)自動增強

(4)mixup訓練

(5)大的切割設定

(6)正則化

6、相關結果 

附錄中還有一些結果,就不再貼了。 

最後是split attention block的實現程式碼,可以結合看一看:

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
from torch.nn.modules.utils import _pair

__all__ = ['SKConv2d']

class DropBlock2D(object):
    def __init__(self, *args, **kwargs):
        raise NotImplementedError

class SplAtConv2d(Module):
    """Split-Attention Conv2d
    """
    def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
                 dilation=(1, 1), groups=1, bias=True,
                 radix=2, reduction_factor=4,
                 rectify=False, rectify_avg=False, norm_layer=None,
                 dropblock_prob=0.0, **kwargs):
        super(SplAtConv2d, self).__init__()
        padding = _pair(padding)
        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
        self.rectify_avg = rectify_avg
        inter_channels = max(in_channels*radix//reduction_factor, 32)
        self.radix = radix
        self.cardinality = groups
        self.channels = channels
        self.dropblock_prob = dropblock_prob
        if self.rectify:
            from rfconv import RFConv2d
            self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
                                 groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs)
        else:
            self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
                               groups=groups*radix, bias=bias, **kwargs)
        self.use_bn = norm_layer is not None
        self.bn0 = norm_layer(channels*radix)
        self.relu = ReLU(inplace=True)
        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
        self.bn1 = norm_layer(inter_channels)
        self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality)
        if dropblock_prob > 0.0:
            self.dropblock = DropBlock2D(dropblock_prob, 3)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn0(x)
        if self.dropblock_prob > 0.0:
            x = self.dropblock(x)
        x = self.relu(x)

        batch, channel = x.shape[:2]
        if self.radix > 1:
            splited = torch.split(x, channel//self.radix, dim=1)
            gap = sum(splited) 
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        if self.use_bn:
            gap = self.bn1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap).view((batch, self.radix, self.channels))
        if self.radix > 1:
            atten = F.softmax(atten, dim=1).view(batch, -1, 1, 1)
        else:
            atten = F.sigmoid(atten, dim=1).view(batch, -1, 1, 1)

        if self.radix > 1:
            atten = torch.split(atten, channel//self.radix, dim=1)
            out = sum([att*split for (att, split) in zip(atten, splited)])
        else:
            out = atten * x
        return out.contiguous()

 

如有錯誤,歡迎指出。

相關文章