PyTorch ResNet 使用與原始碼解析

張賢同學發表於2020-09-08

本章程式碼:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py

這篇文章首先會簡單介紹一下 PyTorch 中提供的影像分類的網路,然後重點介紹 ResNet 的使用,以及 ResNet 的原始碼。

模型概覽

torchvision.model中,有很多封裝好的模型。

PyTorch ResNet 使用與原始碼解析

可以分類 3 類:
  • 經典網路
    • alexnet
    • vgg
    • resnet
    • inception
    • densenet
    • googlenet
  • 輕量化網路
    • squeezenet
    • mobilenet
    • shufflenetv2
  • 自動神經結構搜尋方法的網路
    • mnasnet

ResNet18 使用

ResNet 18 為例。

首先載入訓練好的模型引數:

resnet18 = models.resnet18()

# 修改全連線層的輸出
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 2)

# 載入模型引數
checkpoint = torch.load(m_path)
resnet18.load_state_dict(checkpoint['model_state_dict'])
然後比較重要的是把模型放到 GPU 上,並且轉換到`eval`模式:
resnet18.to(device)
resnet18.eval()

在 inference 時,主要流程如下:

  • 程式碼要放在with torch.no_grad():下。torch.no_grad()會關閉反向傳播,可以減少記憶體、加快速度。

  • 根據路徑讀取圖片,把圖片轉換為 tensor,然後使用unsqueeze_(0)方法把形狀擴大為 \(B \times C \times H \times W\),再把 tensor 放到 GPU 上 。

  • 模型的輸出資料outputs的形狀是 $1 \times 2$,表示 batch_size 為 1,分類數量為 2。torch.max(outputs,0)是返回outputs每一列最大的元素和索引,torch.max(outputs,1)是返回outputs每一行最大的元素和索引。

    這裡使用_, pred_int = torch.max(outputs.data, 1)返回最大元素的索引,然後根據索引獲得 label:pred_str = classes[int(pred_int)]

關鍵程式碼如下:

    with torch.no_grad():
        for idx, img_name in enumerate(img_names):

            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor.unsqueeze_(0)
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            outputs = resnet18(img_tensor)

            # step 4/4 : get label
            _, pred_int = torch.max(outputs.data, 1)
            pred_str = classes[int(pred_int)]

全部程式碼如下所示:

import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
import enviroments
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# config
vis = True
# vis = False
vis_row = 4

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

classes = ["ants", "bees"]


def img_transform(img_rgb, transform=None):
    """
    將資料轉換為模型讀取的形式
    :param img_rgb: PIL Image
    :param transform: torchvision.transform
    :return: tensor
    """

    if transform is None:
        raise ValueError("找不到transform!必須有transform對img進行處理")

    img_t = transform(img_rgb)
    return img_t


def get_img_name(img_dir, format="jpg"):
    """
    獲取資料夾下format格式的檔名
    :param img_dir: str
    :param format: str
    :return: list
    """
    file_names = os.listdir(img_dir)
    # 使用 list(filter(lambda())) 篩選出 jpg 字尾的檔案
    img_names = list(filter(lambda x: x.endswith(format), file_names))

    if len(img_names) < 1:
        raise ValueError("{}下找不到{}格式資料".format(img_dir, format))
    return img_names


def get_model(m_path, vis_model=False):

    resnet18 = models.resnet18()

    # 修改全連線層的輸出
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, 2)

    # 載入模型引數
    checkpoint = torch.load(m_path)
    resnet18.load_state_dict(checkpoint['model_state_dict'])


    if vis_model:
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")

    return resnet18


if __name__ == "__main__":

    img_dir = os.path.join(enviroments.hymenoptera_data_dir,"val/bees")
    model_path = "./checkpoint_14_epoch.pkl"
    time_total = 0
    img_list, img_pred = list(), list()

    # 1. data
    img_names = get_img_name(img_dir)
    num_img = len(img_names)

    # 2. model
    resnet18 = get_model(model_path, True)
    resnet18.to(device)
    resnet18.eval()

    with torch.no_grad():
        for idx, img_name in enumerate(img_names):

            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor.unsqueeze_(0)
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            time_tic = time.time()
            outputs = resnet18(img_tensor)
            time_toc = time.time()

            # step 4/4 : visualization
            _, pred_int = torch.max(outputs.data, 1)
            pred_str = classes[int(pred_int)]

            if vis:
                img_list.append(img_rgb)
                img_pred.append(pred_str)

                if (idx+1) % (vis_row*vis_row) == 0 or num_img == idx+1:
                    for i in range(len(img_list)):
                        plt.subplot(vis_row, vis_row, i+1).imshow(img_list[i])
                        plt.title("predict:{}".format(img_pred[i]))
                    plt.show()
                    plt.close()
                    img_list, img_pred = list(), list()

            time_s = time_toc-time_tic
            time_total += time_s

            print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

    print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".
          format(device, time_total, time_total/num_img))
    if torch.cuda.is_available():
        print("GPU name:{}".format(torch.cuda.get_device_name()))

總結一下 inference 階段需要注意的事項:

  • 確保 model 處於 eval 狀態,而非 trainning 狀態
  • 設定 torch.no_grad(),減少記憶體消耗,加快運算速度
  • 資料預處理需要保持一致,比如 RGB 或者 rBGR

殘差連線

以 ResNet 為例:

PyTorch ResNet 使用與原始碼解析

一個殘差塊有2條路徑 $F(x)$ 和 $x$,$F(x)$ 路徑擬合殘差,不妨稱之為殘差路徑;$x$ 路徑為`identity mapping`恆等對映,稱之為`shortcut`。圖中的⊕為`element-wise addition`,要求參與運算的 $F(x)$ 和 $x$ 的尺寸要相同。

shortcut 路徑大致可以分成 2 種,取決於殘差路徑是否改變了feature map數量和尺寸。

  • 一種是將輸入x原封不動地輸出。
  • 另一種則需要經過 $1×1$ 卷積來升維或者降取樣,主要作用是將輸出與 \(F(x)\) 路徑的輸出保持shape一致,對網路效能的提升並不明顯。

兩種結構如下圖所示:

PyTorch ResNet 使用與原始碼解析

`ResNet` 中,使用了上面 2 種 `shortcut`。

網路結構

ResNet 有很多變種,包括 ResNet 18ResNet 34ResNet 50ResNet 101ResNet 152,網路結構對比如下:

PyTorch ResNet 使用與原始碼解析

`ResNet` 的各個變種,資料處理大致流程如下:
  • 輸入的圖片形狀是 $3 \times 224 \times 224$。
  • 圖片經過 conv1 層,輸出圖片大小為 $ 64 \times 112 \times 112$。
  • 圖片經過 max pool 層,輸出圖片大小為 \(64 \times 56 \times 56\)
  • 圖片經過 conv2 層,輸出圖片大小為 $ 64 \times 56 \times 56$。(注意,圖片經過這個 layer, 大小是不變的)
  • 圖片經過 conv3 層,輸出圖片大小為 $ 128 \times 28 \times 28$。
  • 圖片經過 conv4 層,輸出圖片大小為 $ 256 \times 14 \times 14$。
  • 圖片經過 conv5 層,輸出圖片大小為 $ 512 \times 7 \times 7$。
  • 圖片經過 avg pool 層,輸出大小為 $ 512 \times 1 \times 1$。
  • 圖片經過 fc 層,輸出維度為 $ num_classes$,表示每個分類的 logits

下面,我們稱每個 conv 層為一個 layer(第一個 conv 層就是一個卷積層,因此第一個 conv 層除外)。

其中 ResNet 18ResNet 34 的每個 layer 由多個 BasicBlock 組成,只是每個 layer 裡堆疊的 BasicBlock 數量不一樣。

ResNet 50ResNet 101ResNet 152 的每個 layer 由多個 Bottleneck 組成,只是每個 layer 裡堆疊的 Bottleneck 數量不一樣。

原始碼分析

我們來看看各個 ResNet 的原始碼,首先從建構函式開始。

建構函式

ResNet 18

resnet18 的建構函式如下。

[2, 2, 2, 2] 表示有 4 個 layer,每個 layer 中有 2 個 BasicBlock

conv1為 1 層,conv2conv3conv4conv5均為 4 層(每個 layer 有 2 個 BasicBlock,每個 BasicBlock 有 2 個卷積層),總共為 16 層,最後一層全連線層,$ 總層數 = 1+ 4 \times 4 + 1 = 18$,依此類推。

def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)

ResNet 34

resnet 34 的建構函式如下。

[3, 4, 6, 3] 表示有 4 個 layer,每個 layerBasicBlock 數量分別為 3, 4, 6, 3。

def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)

ResNet 50

resnet 34 的建構函式如下。

[3, 4, 6, 3] 表示有 4 個 layer,每個 layerBottleneck 數量分別為 3, 4, 6, 3。

def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)

依此類推,ResNet 101ResNet 152 也是由多個 layer 組成的。

_resnet()

上面所有的建構函式中,都呼叫了 _resnet() 方法來建立網路,下面來看看 _resnet() 方法。

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    # 載入預訓練好的模型引數
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

可以看到,在 _resnet() 方法中,又呼叫了 ResNet() 方法建立模型,然後載入訓練好的模型引數。

ResNet()

首先來看 ResNet() 方法的建構函式。

建構函式

建構函式的重要引數如下:

  • block:每個 layer 裡面使用的 block,可以是 BasicBlock Bottleneck
  • num_classes:分類數量,用於構建最後的全連線層。
  • layers:一個 list,表示每個 layerblock 的數量。

建構函式的主要流程如下:

  • 判斷是否傳入 norm_layer,沒有傳入,則使用 BatchNorm2d

  • 判斷是否傳入孔洞卷積引數 replace_stride_with_dilation,如果不指定,則賦值為 [False, False, False],表示不使用孔洞卷積。

  • 讀取分組卷積的引數 groupswidth_per_group

  • 然後真正開始構造網路。

  • conv1 層的結構是 Conv2d -> norm_layer -> ReLU

    self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    
  • conv2 層的程式碼如下,對應於 layer1,這個 layer 的引數沒有指定 stride,預設 stride=1,因此這個 layer 不會改變圖片大小:

    self.layer1 = self._make_layer(block, 64, layers[0])
    
  • conv3 層的程式碼如下,對應於 layer2(注意這個 layer 指定 stride=2,會降取樣,詳情看下面 _make_layer 的講解):

    self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
    
  • conv4 層的程式碼如下,對應於 layer3(注意這個 layer 指定 stride=2,會降取樣,詳情看下面 _make_layer 的講解):

    self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
    dilate=replace_stride_with_dilation[1])
    
  • conv5 層的程式碼如下,對應於 layer4(注意這個 layer 指定 stride=2,會降取樣,詳情看下面 _make_layer 的講解):

    self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
    dilate=replace_stride_with_dilation[2])
    
  • 接著是 AdaptiveAvgPool2d 層和 fc 層。

  • 最後是網路引數的初始:

    • 卷積層採用 kaiming_normal_() 初始化方法。
    • bn 層和 GroupNorm 層初始化為 weight=1bias=0
    • 其中每個 BasicBlockBottleneck 的最後一層 bnweight=0,可以提升準確率 0.2~0.3%。

完整的建構函式程式碼如下:

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        # 使用 bn 層
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # 對應於 conv1
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # 對應於 conv2
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        # 對應於 conv3
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        對應於 conv4
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        對應於 conv5
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

forward()

ResNet 中,網路經過層層封裝,因此forward() 方法非常簡潔。

資料變換大致流程如下:

  • 輸入的圖片形狀是 $3 \times 224 \times 224$。
  • 圖片經過 conv1 層,輸出圖片大小為 $ 64 \times 112 \times 112$。
  • 圖片經過 max pool 層,輸出圖片大小為 \(64 \times 56 \times 56\)
  • 對於 ResNet 18ResNet 34 (使用 BasicBlock):
    • 圖片經過 conv2 層,對應於 layer1,輸出圖片大小為 $ 64 \times 56 \times 56$。(注意,圖片經過這個 layer, 大小是不變的)
    • 圖片經過 conv3 層,對應於 layer2,輸出圖片大小為 $ 128 \times 28 \times 28$。
    • 圖片經過 conv4 層,對應於 layer3,輸出圖片大小為 $ 256 \times 14 \times 14$。
    • 圖片經過 conv5 層,對應於 layer4,輸出圖片大小為 $ 512 \times 7 \times 7$。
    • 圖片經過 avg pool 層,輸出大小為 $ 512 \times 1 \times 1$。
  • 對於 ResNet 50ResNet 101ResNet 152(使用 Bottleneck):
    • 圖片經過 conv2 層,對應於 layer1,輸出圖片大小為 $ 256 \times 56 \times 56$。(注意,圖片經過這個 layer, 大小是不變的)
    • 圖片經過 conv3 層,對應於 layer2,輸出圖片大小為 $ 512 \times 28 \times 28$。
    • 圖片經過 conv4 層,對應於 layer3,輸出圖片大小為 $ 1024 \times 14 \times 14$。
    • 圖片經過 conv5 層,對應於 layer4,輸出圖片大小為 $ 2048 \times 7 \times 7$。
    • 圖片經過 avg pool 層,輸出大小為 $ 2048 \times 1 \times 1$。
  • 圖片經過 fc 層,輸出維度為 $ num_classes$,表示每個分類的 logits
    def _forward_impl(self, x):
        # See note [TorchScript super()]

        # conv1
        # x: [3, 224, 224] -> [64, 112, 112]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # conv2
        # x: [64, 112, 112] -> [64, 56, 56]
        x = self.maxpool(x)

		# x: [64, 56, 56] -> [64, 56, 56]
		# x 經過第一個 layer, 大小是不變的
        x = self.layer1(x)

        # conv3
        # x: [64, 56, 56] -> [128, 28, 28]
        x = self.layer2(x)

        # conv4
        # x: [128, 28, 28] -> [256, 14, 14]
        x = self.layer3(x)

        # conv5
        # x: [256, 14, 14] -> [512, 7, 7]
        x = self.layer4(x)

		# x: [512, 7, 7] -> [512, 1, 1]
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

在建構函式中可以看到,上面每個 layer 都是使用 _make_layer() 方法來建立層的,下面來看下 _make_layer() 方法。

_make_layer()

_make_layer()方法的引數如下:

  • block:每個 layer 裡面使用的 block,可以是 BasicBlockBottleneck
  • planes:輸出的通道數
  • blocks:一個整數,表示該層 layer 有多少個 block
  • stride:第一個 block 的卷積層的 stride,預設為 1。注意,只有在每個 layer 的第一個 block 的第一個卷積層使用該引數。
  • dilate:是否使用孔洞卷積。

主要流程如下:

  • 判斷孔洞卷積,計算 previous_dilation 引數。

  • 判斷 stride 是否為 1,輸入通道和輸出通道是否相等。如果這兩個條件都不成立,那麼表明需要建立一個 1 X 1 的卷積層,來改變通道數和改變圖片大小。具體是建立 downsample 層,包括 conv1x1 -> norm_layer

  • 建立第一個 block,把 downsample 傳給 block 作為降取樣的層,並且 stride 也使用傳入的 stride(stride=2)。後面我們會分析 downsample 層在 BasicBlockBottleneck 中,具體是怎麼用的

  • 改變通道數self.inplanes = planes * block.expansion

    • BasicBlock 裡,expansion=1,因此這一步不會改變通道數
    • Bottleneck 裡,expansion=4,因此這一步會改變通道數
  • 圖片經過第一個 block後,就會改變通道數和圖片大小。接下來 for 迴圈新增剩下的 block。從第 2 個 block 起,輸入和輸出通道數是相等的,因此就不用傳入 downsamplestride(那麼 blockstride 預設使用 1,下面我們會分析 BasicBlockBottleneck 的原始碼)。

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # 首先判斷 stride 是否為1,輸入通道和輸出通道是否相等。不相等則使用 1 X 1 的卷積改變大小和通道
        #作為 downsample
        # 在 Resnet 中,每層 layer 傳入的 stride =2
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # 然後新增第一個 basic block,把 downsample 傳給 BasicBlock 作為降取樣的層。
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        # 修改輸出的通道數
        self.inplanes = planes * block.expansion
        # 繼續新增這個 layer 裡接下來的 BasicBlock
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

下面來看 BasicBlockBottleneck 的原始碼。

BasicBlock

建構函式

BasicBlock 建構函式的主要引數如下:

  • inplanes:輸入通道數。

  • planes:輸出通道數。

  • stride:第一個卷積層的 stride

  • downsample:從 layer 中傳入的 downsample 層。

  • groups:分組卷積的分組數,使用 1

  • base_width:每組卷積的通道數,使用 64

  • dilation:孔洞卷積,為 1,表示不使用 孔洞卷積

主要流程如下:

  • 首先判斷是否傳入了 norm_layer 層,如果沒有,則使用 BatchNorm2d
  • 校驗引數:groups == 1base_width == 64dilation == 1。也就是說,在 BasicBlock 中,不使用孔洞卷積和分組卷積。
  • 定義第 1 組 conv3x3 -> norm_layer -> relu,這裡使用傳入的 strideinplanes。(如果是 layer2layer3layer4 裡的第一個 BasicBlock,那麼 stride=2,這裡會降取樣和改變通道數)。
  • 定義第 2 組 conv3x3 -> norm_layer -> relu,這裡不使用傳入的 stride (預設為 1),輸入通道數和輸出通道數使用planes,也就是不需要降取樣和改變通道數
class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

forward()

forward() 方法的主要流程如下:

  • x 賦值給 identity,用於後面的 shortcut 連線。
  • x 經過第 1 組 conv3x3 -> norm_layer -> relu,如果是 layer2layer3layer4 裡的第一個 BasicBlock,那麼 stride=2,第一個卷積層會降取樣。
  • x 經過第 1 組 conv3x3 -> norm_layer,得到 out
  • 如果是 layer2layer3layer4 裡的第一個 BasicBlock,那麼 downsample 不為空,會經過 downsample 層,得到 identity
  • 最後將 identityout 相加,經過 relu ,得到輸出。

注意,2 個卷積層都需要經過 relu 層,但它們使用的是同一個 relu 層。

    def forward(self, x):
        identity = x
		# 如果是 layer2,layer3,layer4 裡的第一個 BasicBlock,第一個卷積層會降取樣
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

Bottleneck

建構函式

引數如下:

  • inplanes:輸入通道數。
  • planes:輸出通道數。
  • stride:第一個卷積層的 stride
  • downsample:從 layer 中傳入的 downsample 層。
  • groups:分組卷積的分組數,使用 1
  • base_width:每組卷積的通道數,使用 64
  • dilation:孔洞卷積,為 1,表示不使用 孔洞卷積

主要流程如下:

  • 首先判斷是否傳入了 norm_layer 層,如果沒有,則使用 BatchNorm2d
  • 計算 width,等於傳入的 planes,用於中間的 \(3 \times 3\) 卷積。
  • 定義第 1 組 conv1x1 -> norm_layer,這裡不使用傳入的 stride,使用 width,作用是進行降維,減少通道數。
  • 定義第 2 組 conv3x3 -> norm_layer,這裡使用傳入的 stride,輸入通道數和輸出通道數使用width。(如果是 layer2layer3layer4 裡的第一個 Bottleneck,那麼 stride=2,這裡會降取樣)。
  • 定義第 3 組 conv1x1 -> norm_layer,這裡不使用傳入的 stride,使用 planes * self.expansion,作用是進行升維,增加通道數。
class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        # base_width = 64
        # groups =1
        # width = planes
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        # 1x1 的卷積是為了降維,減少通道數
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        # 3x3 的卷積是為了改變圖片大小,不改變通道數
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        # 1x1 的卷積是為了升維,增加通道數,增加到 planes * 4
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

forward()

forward() 方法的主要流程如下:

  • x 賦值給 identity,用於後面的 shortcut 連線。
  • x 經過第 1 組 conv1x1 -> norm_layer -> relu,作用是進行降維,減少通道數。
  • x 經過第 2 組 conv3x3 -> norm_layer -> relu。如果是 layer2layer3layer4 裡的第一個 Bottleneck,那麼 stride=2,第一個卷積層會降取樣。
  • x 經過第 1 組 conv1x1 -> norm_layer -> relu,作用是進行降維,減少通道數。
  • 如果是 layer2layer3layer4 裡的第一個 Bottleneck,那麼 downsample 不為空,會經過 downsample 層,得到 identity
  • 最後將 identityout 相加,經過 relu ,得到輸出。

注意,3 個卷積層都需要經過 relu 層,但它們使用的是同一個 relu 層。

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

總結

最後,總結一下。

  • BasicBlock 中有 1 個 $3 \times 3 $ 卷積層,如果是 layer 的第一個 BasicBlock,那麼第一個卷積層的 stride=2,作用是進行降取樣。
  • Bottleneck 中有 2 個 $1 \times 1 $ 卷積層, 1 個 $3 \times 3 $ 卷積層。先經過第 1 個 $1 \times 1 $ 卷積層,進行降維,然後經過 $3 \times 3 $ 卷積層(如果是 layer 的第一個 Bottleneck,那麼 $3 \times 3 $ 卷積層的 stride=2,作用是進行降取樣),最後經過 $1 \times 1 $ 卷積層,進行升維 。

ResNet 18 圖解

layer1

下面是 ResNet 18 ,使用的是 BasicBlocklayer1,特點是沒有進行降取樣,卷積層的 stride = 1,不會降取樣。在進行 shortcut 連線時,也沒有經過 downsample 層。

PyTorch ResNet 使用與原始碼解析

layer2,layer3,layer4

layer2layer3layer4 的結構圖如下,每個 layer 包含 2 個 BasicBlock,但是第 1 個 BasicBlock 的第 1 個卷積層的 stride = 2,會進行降取樣。在進行 shortcut 連線時,會經過 downsample 層,進行降取樣和降維

PyTorch ResNet 使用與原始碼解析

ResNet 50 圖解

layer1

layer1 中,首先第一個 Bottleneck 只會進行升維,不會降取樣。shortcut 連線前,會經過 downsample 層升維處理。第二個 Bottleneckshortcut 連線不會經過 downsample 層。

PyTorch ResNet 使用與原始碼解析

layer2,layer3,layer4

layer2layer3layer4 的結構圖如下,每個 layer 包含多個 Bottleneck,但是第 1 個 Bottleneck\(3 \times 3\) 卷積層的 stride = 2,會進行降取樣。在進行 shortcut 連線時,會經過 downsample 層,進行降取樣和降維

PyTorch ResNet 使用與原始碼解析


如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章。

相關文章