本章程式碼:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py
這篇文章首先會簡單介紹一下 PyTorch
中提供的影像分類的網路,然後重點介紹 ResNet
的使用,以及 ResNet
的原始碼。
模型概覽
在torchvision.model
中,有很多封裝好的模型。
可以分類 3 類:
- 經典網路
- alexnet
- vgg
- resnet
- inception
- densenet
- googlenet
- 輕量化網路
- squeezenet
- mobilenet
- shufflenetv2
- 自動神經結構搜尋方法的網路
- mnasnet
ResNet18 使用
以 ResNet 18
為例。
首先載入訓練好的模型引數:
resnet18 = models.resnet18()
# 修改全連線層的輸出
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 2)
# 載入模型引數
checkpoint = torch.load(m_path)
resnet18.load_state_dict(checkpoint['model_state_dict'])
然後比較重要的是把模型放到 GPU 上,並且轉換到`eval`模式:
resnet18.to(device)
resnet18.eval()
在 inference 時,主要流程如下:
程式碼要放在
with torch.no_grad():
下。torch.no_grad()
會關閉反向傳播,可以減少記憶體、加快速度。根據路徑讀取圖片,把圖片轉換為 tensor,然後使用
unsqueeze_(0)
方法把形狀擴大為 \(B \times C \times H \times W\),再把 tensor 放到 GPU 上 。模型的輸出資料
outputs
的形狀是 $1 \times 2$,表示batch_size
為 1,分類數量為 2。torch.max(outputs,0)
是返回outputs
中每一列最大的元素和索引,torch.max(outputs,1)
是返回outputs
中每一行最大的元素和索引。這裡使用
_, pred_int = torch.max(outputs.data, 1)
返回最大元素的索引,然後根據索引獲得 label:pred_str = classes[int(pred_int)]
。
關鍵程式碼如下:
with torch.no_grad():
for idx, img_name in enumerate(img_names):
path_img = os.path.join(img_dir, img_name)
# step 1/4 : path --> img
img_rgb = Image.open(path_img).convert('RGB')
# step 2/4 : img --> tensor
img_tensor = img_transform(img_rgb, inference_transform)
img_tensor.unsqueeze_(0)
img_tensor = img_tensor.to(device)
# step 3/4 : tensor --> vector
outputs = resnet18(img_tensor)
# step 4/4 : get label
_, pred_int = torch.max(outputs.data, 1)
pred_str = classes[int(pred_int)]
全部程式碼如下所示:
import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
import enviroments
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
# config
vis = True
# vis = False
vis_row = 4
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
inference_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
classes = ["ants", "bees"]
def img_transform(img_rgb, transform=None):
"""
將資料轉換為模型讀取的形式
:param img_rgb: PIL Image
:param transform: torchvision.transform
:return: tensor
"""
if transform is None:
raise ValueError("找不到transform!必須有transform對img進行處理")
img_t = transform(img_rgb)
return img_t
def get_img_name(img_dir, format="jpg"):
"""
獲取資料夾下format格式的檔名
:param img_dir: str
:param format: str
:return: list
"""
file_names = os.listdir(img_dir)
# 使用 list(filter(lambda())) 篩選出 jpg 字尾的檔案
img_names = list(filter(lambda x: x.endswith(format), file_names))
if len(img_names) < 1:
raise ValueError("{}下找不到{}格式資料".format(img_dir, format))
return img_names
def get_model(m_path, vis_model=False):
resnet18 = models.resnet18()
# 修改全連線層的輸出
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 2)
# 載入模型引數
checkpoint = torch.load(m_path)
resnet18.load_state_dict(checkpoint['model_state_dict'])
if vis_model:
from torchsummary import summary
summary(resnet18, input_size=(3, 224, 224), device="cpu")
return resnet18
if __name__ == "__main__":
img_dir = os.path.join(enviroments.hymenoptera_data_dir,"val/bees")
model_path = "./checkpoint_14_epoch.pkl"
time_total = 0
img_list, img_pred = list(), list()
# 1. data
img_names = get_img_name(img_dir)
num_img = len(img_names)
# 2. model
resnet18 = get_model(model_path, True)
resnet18.to(device)
resnet18.eval()
with torch.no_grad():
for idx, img_name in enumerate(img_names):
path_img = os.path.join(img_dir, img_name)
# step 1/4 : path --> img
img_rgb = Image.open(path_img).convert('RGB')
# step 2/4 : img --> tensor
img_tensor = img_transform(img_rgb, inference_transform)
img_tensor.unsqueeze_(0)
img_tensor = img_tensor.to(device)
# step 3/4 : tensor --> vector
time_tic = time.time()
outputs = resnet18(img_tensor)
time_toc = time.time()
# step 4/4 : visualization
_, pred_int = torch.max(outputs.data, 1)
pred_str = classes[int(pred_int)]
if vis:
img_list.append(img_rgb)
img_pred.append(pred_str)
if (idx+1) % (vis_row*vis_row) == 0 or num_img == idx+1:
for i in range(len(img_list)):
plt.subplot(vis_row, vis_row, i+1).imshow(img_list[i])
plt.title("predict:{}".format(img_pred[i]))
plt.show()
plt.close()
img_list, img_pred = list(), list()
time_s = time_toc-time_tic
time_total += time_s
print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))
print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".
format(device, time_total, time_total/num_img))
if torch.cuda.is_available():
print("GPU name:{}".format(torch.cuda.get_device_name()))
總結一下 inference 階段需要注意的事項:
- 確保 model 處於 eval 狀態,而非 trainning 狀態
- 設定 torch.no_grad(),減少記憶體消耗,加快運算速度
- 資料預處理需要保持一致,比如 RGB 或者 rBGR
殘差連線
以 ResNet 為例:
一個殘差塊有2條路徑 $F(x)$ 和 $x$,$F(x)$ 路徑擬合殘差,不妨稱之為殘差路徑;$x$ 路徑為`identity mapping`恆等對映,稱之為`shortcut`。圖中的⊕為`element-wise addition`,要求參與運算的 $F(x)$ 和 $x$ 的尺寸要相同。
shortcut
路徑大致可以分成 2 種,取決於殘差路徑是否改變了feature map
數量和尺寸。
- 一種是將輸入
x
原封不動地輸出。 - 另一種則需要經過 $1×1$ 卷積來升維或者降取樣,主要作用是將輸出與 \(F(x)\) 路徑的輸出保持
shape
一致,對網路效能的提升並不明顯。
兩種結構如下圖所示:
`ResNet` 中,使用了上面 2 種 `shortcut`。
網路結構
ResNet 有很多變種,包括 ResNet 18
、ResNet 34
、ResNet 50
、ResNet 101
、ResNet 152
,網路結構對比如下:
`ResNet` 的各個變種,資料處理大致流程如下:
- 輸入的圖片形狀是 $3 \times 224 \times 224$。
- 圖片經過
conv1
層,輸出圖片大小為 $ 64 \times 112 \times 112$。 - 圖片經過
max pool
層,輸出圖片大小為 \(64 \times 56 \times 56\)。 - 圖片經過
conv2
層,輸出圖片大小為 $ 64 \times 56 \times 56$。(注意,圖片經過這個layer
, 大小是不變的) - 圖片經過
conv3
層,輸出圖片大小為 $ 128 \times 28 \times 28$。 - 圖片經過
conv4
層,輸出圖片大小為 $ 256 \times 14 \times 14$。 - 圖片經過
conv5
層,輸出圖片大小為 $ 512 \times 7 \times 7$。 - 圖片經過
avg pool
層,輸出大小為 $ 512 \times 1 \times 1$。 - 圖片經過
fc
層,輸出維度為 $ num_classes$,表示每個分類的logits
。
下面,我們稱每個 conv
層為一個 layer
(第一個 conv
層就是一個卷積層,因此第一個 conv
層除外)。
其中 ResNet 18
、ResNet 34
的每個 layer
由多個 BasicBlock
組成,只是每個 layer
裡堆疊的 BasicBlock
數量不一樣。
而 ResNet 50
、ResNet 101
、ResNet 152
的每個 layer
由多個 Bottleneck
組成,只是每個 layer
裡堆疊的 Bottleneck
數量不一樣。
原始碼分析
我們來看看各個 ResNet
的原始碼,首先從建構函式開始。
建構函式
ResNet 18
resnet18
的建構函式如下。
[2, 2, 2, 2]
表示有 4 個 layer
,每個 layer 中有 2 個 BasicBlock
。
conv1
為 1 層,conv2
、conv3
、conv4
、conv5
均為 4 層(每個 layer
有 2 個 BasicBlock
,每個 BasicBlock
有 2 個卷積層),總共為 16 層,最後一層全連線層,$ 總層數 = 1+ 4 \times 4 + 1 = 18$,依此類推。
def resnet18(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
ResNet 34
resnet 34
的建構函式如下。
[3, 4, 6, 3]
表示有 4 個 layer
,每個 layer
的 BasicBlock
數量分別為 3, 4, 6, 3。
def resnet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
ResNet 50
resnet 34
的建構函式如下。
[3, 4, 6, 3]
表示有 4 個 layer
,每個 layer
的 Bottleneck
數量分別為 3, 4, 6, 3。
def resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
依此類推,ResNet 101
和 ResNet 152
也是由多個 layer
組成的。
_resnet()
上面所有的建構函式中,都呼叫了 _resnet()
方法來建立網路,下面來看看 _resnet()
方法。
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
# 載入預訓練好的模型引數
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
可以看到,在 _resnet()
方法中,又呼叫了 ResNet()
方法建立模型,然後載入訓練好的模型引數。
ResNet()
首先來看 ResNet()
方法的建構函式。
建構函式
建構函式的重要引數如下:
- block:每個
layer
裡面使用的block
,可以是BasicBlock
Bottleneck
。 - num_classes:分類數量,用於構建最後的全連線層。
- layers:一個 list,表示每個
layer
中block
的數量。
建構函式的主要流程如下:
判斷是否傳入
norm_layer
,沒有傳入,則使用BatchNorm2d
。判斷是否傳入孔洞卷積引數
replace_stride_with_dilation
,如果不指定,則賦值為[False, False, False]
,表示不使用孔洞卷積。讀取分組卷積的引數
groups
,width_per_group
。然後真正開始構造網路。
conv1
層的結構是Conv2d -> norm_layer -> ReLU
。self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True)
conv2
層的程式碼如下,對應於layer1
,這個layer
的引數沒有指定stride
,預設stride=1
,因此這個layer
不會改變圖片大小:self.layer1 = self._make_layer(block, 64, layers[0])
conv3
層的程式碼如下,對應於layer2
(注意這個layer
指定stride=2
,會降取樣,詳情看下面_make_layer
的講解):self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
conv4
層的程式碼如下,對應於layer3
(注意這個layer
指定stride=2
,會降取樣,詳情看下面_make_layer
的講解):self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
conv5
層的程式碼如下,對應於layer4
(注意這個layer
指定stride=2
,會降取樣,詳情看下面_make_layer
的講解):self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
接著是
AdaptiveAvgPool2d
層和fc
層。最後是網路引數的初始:
- 卷積層採用
kaiming_normal_()
初始化方法。 bn
層和GroupNorm
層初始化為weight=1
,bias=0
。- 其中每個
BasicBlock
和Bottleneck
的最後一層bn
的weight=0
,可以提升準確率 0.2~0.3%。
- 卷積層採用
完整的建構函式程式碼如下:
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
# 使用 bn 層
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
# 對應於 conv1
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
# 對應於 conv2
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
# 對應於 conv3
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
對應於 conv4
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
對應於 conv5
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
forward()
在 ResNet
中,網路經過層層封裝,因此forward()
方法非常簡潔。
資料變換大致流程如下:
- 輸入的圖片形狀是 $3 \times 224 \times 224$。
- 圖片經過
conv1
層,輸出圖片大小為 $ 64 \times 112 \times 112$。 - 圖片經過
max pool
層,輸出圖片大小為 \(64 \times 56 \times 56\)。 - 對於
ResNet 18
、ResNet 34
(使用BasicBlock
):- 圖片經過
conv2
層,對應於layer1
,輸出圖片大小為 $ 64 \times 56 \times 56$。(注意,圖片經過這個layer
, 大小是不變的) - 圖片經過
conv3
層,對應於layer2
,輸出圖片大小為 $ 128 \times 28 \times 28$。 - 圖片經過
conv4
層,對應於layer3
,輸出圖片大小為 $ 256 \times 14 \times 14$。 - 圖片經過
conv5
層,對應於layer4
,輸出圖片大小為 $ 512 \times 7 \times 7$。 - 圖片經過
avg pool
層,輸出大小為 $ 512 \times 1 \times 1$。
- 圖片經過
- 對於
ResNet 50
、ResNet 101
、ResNet 152
(使用Bottleneck
):- 圖片經過
conv2
層,對應於layer1
,輸出圖片大小為 $ 256 \times 56 \times 56$。(注意,圖片經過這個layer
, 大小是不變的) - 圖片經過
conv3
層,對應於layer2
,輸出圖片大小為 $ 512 \times 28 \times 28$。 - 圖片經過
conv4
層,對應於layer3
,輸出圖片大小為 $ 1024 \times 14 \times 14$。 - 圖片經過
conv5
層,對應於layer4
,輸出圖片大小為 $ 2048 \times 7 \times 7$。 - 圖片經過
avg pool
層,輸出大小為 $ 2048 \times 1 \times 1$。
- 圖片經過
- 圖片經過
fc
層,輸出維度為 $ num_classes$,表示每個分類的logits
。
def _forward_impl(self, x):
# See note [TorchScript super()]
# conv1
# x: [3, 224, 224] -> [64, 112, 112]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
# conv2
# x: [64, 112, 112] -> [64, 56, 56]
x = self.maxpool(x)
# x: [64, 56, 56] -> [64, 56, 56]
# x 經過第一個 layer, 大小是不變的
x = self.layer1(x)
# conv3
# x: [64, 56, 56] -> [128, 28, 28]
x = self.layer2(x)
# conv4
# x: [128, 28, 28] -> [256, 14, 14]
x = self.layer3(x)
# conv5
# x: [256, 14, 14] -> [512, 7, 7]
x = self.layer4(x)
# x: [512, 7, 7] -> [512, 1, 1]
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
在建構函式中可以看到,上面每個 layer
都是使用 _make_layer()
方法來建立層的,下面來看下 _make_layer()
方法。
_make_layer()
_make_layer()
方法的引數如下:
- block:每個
layer
裡面使用的block
,可以是BasicBlock
,Bottleneck
。 - planes:輸出的通道數
- blocks:一個整數,表示該層
layer
有多少個block
。 - stride:第一個
block
的卷積層的stride
,預設為 1。注意,只有在每個layer
的第一個block
的第一個卷積層使用該引數。 - dilate:是否使用孔洞卷積。
主要流程如下:
判斷孔洞卷積,計算
previous_dilation
引數。判斷
stride
是否為 1,輸入通道和輸出通道是否相等。如果這兩個條件都不成立,那麼表明需要建立一個 1 X 1 的卷積層,來改變通道數和改變圖片大小。具體是建立downsample
層,包括conv1x1 -> norm_layer
。建立第一個
block
,把downsample
傳給block
作為降取樣的層,並且stride
也使用傳入的stride
(stride=2)。後面我們會分析downsample
層在BasicBlock
和Bottleneck
中,具體是怎麼用的。改變通道數
self.inplanes = planes * block.expansion
。- 在
BasicBlock
裡,expansion=1
,因此這一步不會改變通道數。 - 在
Bottleneck
裡,expansion=4
,因此這一步會改變通道數。
- 在
圖片經過第一個
block
後,就會改變通道數和圖片大小。接下來 for 迴圈新增剩下的block
。從第 2 個block
起,輸入和輸出通道數是相等的,因此就不用傳入downsample
和stride
(那麼block
的stride
預設使用 1,下面我們會分析BasicBlock
和Bottleneck
的原始碼)。
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
# 首先判斷 stride 是否為1,輸入通道和輸出通道是否相等。不相等則使用 1 X 1 的卷積改變大小和通道
#作為 downsample
# 在 Resnet 中,每層 layer 傳入的 stride =2
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
# 然後新增第一個 basic block,把 downsample 傳給 BasicBlock 作為降取樣的層。
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
# 修改輸出的通道數
self.inplanes = planes * block.expansion
# 繼續新增這個 layer 裡接下來的 BasicBlock
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
下面來看 BasicBlock
和 Bottleneck
的原始碼。
BasicBlock
建構函式
BasicBlock
建構函式的主要引數如下:
inplanes:輸入通道數。
planes:輸出通道數。
stride:第一個卷積層的
stride
。downsample:從
layer
中傳入的downsample
層。groups:分組卷積的分組數,使用 1
base_width:每組卷積的通道數,使用 64
dilation:孔洞卷積,為 1,表示不使用 孔洞卷積
主要流程如下:
- 首先判斷是否傳入了
norm_layer
層,如果沒有,則使用BatchNorm2d
。 - 校驗引數:
groups == 1
,base_width == 64
,dilation == 1
。也就是說,在BasicBlock
中,不使用孔洞卷積和分組卷積。 - 定義第 1 組
conv3x3 -> norm_layer -> relu
,這裡使用傳入的stride
和inplanes
。(如果是layer2
,layer3
,layer4
裡的第一個BasicBlock
,那麼stride=2
,這裡會降取樣和改變通道數)。 - 定義第 2 組
conv3x3 -> norm_layer -> relu
,這裡不使用傳入的stride
(預設為 1),輸入通道數和輸出通道數使用planes
,也就是不需要降取樣和改變通道數。
class BasicBlock(nn.Module):
expansion = 1
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
forward()
forward()
方法的主要流程如下:
x
賦值給identity
,用於後面的shortcut
連線。x
經過第 1 組conv3x3 -> norm_layer -> relu
,如果是layer2
,layer3
,layer4
裡的第一個BasicBlock
,那麼stride=2
,第一個卷積層會降取樣。x
經過第 1 組conv3x3 -> norm_layer
,得到out
。- 如果是
layer2
,layer3
,layer4
裡的第一個BasicBlock
,那麼downsample
不為空,會經過downsample
層,得到identity
。 - 最後將
identity
和out
相加,經過relu
,得到輸出。
注意,2 個卷積層都需要經過
relu
層,但它們使用的是同一個relu
層。
def forward(self, x):
identity = x
# 如果是 layer2,layer3,layer4 裡的第一個 BasicBlock,第一個卷積層會降取樣
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
Bottleneck
建構函式
引數如下:
- inplanes:輸入通道數。
- planes:輸出通道數。
- stride:第一個卷積層的
stride
。 - downsample:從
layer
中傳入的downsample
層。 - groups:分組卷積的分組數,使用 1
- base_width:每組卷積的通道數,使用 64
- dilation:孔洞卷積,為 1,表示不使用 孔洞卷積
主要流程如下:
- 首先判斷是否傳入了
norm_layer
層,如果沒有,則使用BatchNorm2d
。 - 計算
width
,等於傳入的planes
,用於中間的 \(3 \times 3\) 卷積。 - 定義第 1 組
conv1x1 -> norm_layer
,這裡不使用傳入的stride
,使用width
,作用是進行降維,減少通道數。 - 定義第 2 組
conv3x3 -> norm_layer
,這裡使用傳入的stride
,輸入通道數和輸出通道數使用width
。(如果是layer2
,layer3
,layer4
裡的第一個Bottleneck
,那麼stride=2
,這裡會降取樣)。 - 定義第 3 組
conv1x1 -> norm_layer
,這裡不使用傳入的stride
,使用planes * self.expansion
,作用是進行升維,增加通道數。
class Bottleneck(nn.Module):
expansion = 4
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
# base_width = 64
# groups =1
# width = planes
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
# 1x1 的卷積是為了降維,減少通道數
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
# 3x3 的卷積是為了改變圖片大小,不改變通道數
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
# 1x1 的卷積是為了升維,增加通道數,增加到 planes * 4
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
forward()
forward()
方法的主要流程如下:
x
賦值給identity
,用於後面的shortcut
連線。x
經過第 1 組conv1x1 -> norm_layer -> relu
,作用是進行降維,減少通道數。x
經過第 2 組conv3x3 -> norm_layer -> relu
。如果是layer2
,layer3
,layer4
裡的第一個Bottleneck
,那麼stride=2
,第一個卷積層會降取樣。x
經過第 1 組conv1x1 -> norm_layer -> relu
,作用是進行降維,減少通道數。- 如果是
layer2
,layer3
,layer4
裡的第一個Bottleneck
,那麼downsample
不為空,會經過downsample
層,得到identity
。 - 最後將
identity
和out
相加,經過relu
,得到輸出。
注意,3 個卷積層都需要經過
relu
層,但它們使用的是同一個relu
層。
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
總結
最後,總結一下。
BasicBlock
中有 1 個 $3 \times 3 $ 卷積層,如果是layer
的第一個BasicBlock
,那麼第一個卷積層的stride=2
,作用是進行降取樣。Bottleneck
中有 2 個 $1 \times 1 $ 卷積層, 1 個 $3 \times 3 $ 卷積層。先經過第 1 個 $1 \times 1 $ 卷積層,進行降維,然後經過 $3 \times 3 $ 卷積層(如果是layer
的第一個Bottleneck
,那麼 $3 \times 3 $ 卷積層的stride=2
,作用是進行降取樣),最後經過 $1 \times 1 $ 卷積層,進行升維 。
ResNet 18 圖解
layer1
下面是 ResNet 18
,使用的是 BasicBlock
。layer1
,特點是沒有進行降取樣,卷積層的 stride = 1
,不會降取樣。在進行 shortcut
連線時,也沒有經過 downsample
層。
layer2,layer3,layer4
而 layer2
,layer3
,layer4
的結構圖如下,每個 layer
包含 2 個 BasicBlock
,但是第 1 個 BasicBlock
的第 1 個卷積層的 stride = 2
,會進行降取樣。在進行 shortcut
連線時,會經過 downsample
層,進行降取樣和降維。
ResNet 50 圖解
layer1
在 layer1
中,首先第一個 Bottleneck
只會進行升維,不會降取樣。shortcut
連線前,會經過 downsample
層升維處理。第二個 Bottleneck
的 shortcut
連線不會經過 downsample
層。
layer2,layer3,layer4
而 layer2
,layer3
,layer4
的結構圖如下,每個 layer
包含多個 Bottleneck
,但是第 1 個 Bottleneck
的 \(3 \times 3\) 卷積層的 stride = 2
,會進行降取樣。在進行 shortcut
連線時,會經過 downsample
層,進行降取樣和降維。
如果你覺得這篇文章對你有幫助,不妨點個贊,讓我有更多動力寫出好文章。