VGG網路的Pytorch實現

ysyouaremyall發表於2019-05-08

1.文章原文地址

Very Deep Convolutional Networks for Large-Scale Image Recognition

2.文章摘要

在這項工作中,我們研究了在大規模影像識別任務中,卷積神經網路的深度對其準確率的影響。我們的主要貢獻是使用非常小(3×3)卷積核的架構,對深度不斷增加的網路進行了全面的評估,結果表明將深度增加到16-19個權重層時,效能相比先前的網路配置會有顯著提升。這些發現是我們參加ImageNet Challenge 2014的基礎,我們的團隊在定位和分類任務中分別獲得了第一名和第二名。另外,我們的表徵也可以很好地推廣到其他資料集上,並在這些資料集上獲得了當前最好的結果。為了促進在計算機視覺中使用深度視覺表徵的進一步研究,我們已經公開了兩個效能最佳的ConvNet模型。

3.網路結構

4.Pytorch實現

  1 import torch.nn as nn
  2 try:
  3     from torch.hub import load_state_dict_from_url
  4 except ImportError:
  5     from torch.utils.model_zoo import load_url as load_state_dict_from_url
  6 
  7 __all__ = [
  8     'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
  9     'vgg19_bn', 'vgg19',
 10 ]
 11 
 12 
 13 model_urls = {
 14     'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
 15     'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
 16     'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
 17     'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
 18     'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
 19     'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
 20     'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
 21     'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
 22 }
 23 
 24 
 25 class VGG(nn.Module):
 26 
 27     def __init__(self, features, num_classes=1000, init_weights=True):
 28         super(VGG, self).__init__()
 29         self.features = features
 30         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))  #固定全連線層的輸入
 31         self.classifier = nn.Sequential(
 32             nn.Linear(512 * 7 * 7, 4096),
 33             nn.ReLU(True),
 34             nn.Dropout(),
 35             nn.Linear(4096, 4096),
 36             nn.ReLU(True),
 37             nn.Dropout(),
 38             nn.Linear(4096, num_classes),
 39         )
 40         if init_weights:
 41             self._initialize_weights()
 42 
 43     def forward(self, x):
 44         x = self.features(x)
 45         x = self.avgpool(x)
 46         x = x.view(x.size(0), -1)
 47         x = self.classifier(x)
 48         return x
 49 
 50     def _initialize_weights(self):
 51         for m in self.modules():
 52             if isinstance(m, nn.Conv2d):
 53                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
 54                 if m.bias is not None:
 55                     nn.init.constant_(m.bias, 0)
 56             elif isinstance(m, nn.BatchNorm2d):
 57                 nn.init.constant_(m.weight, 1)
 58                 nn.init.constant_(m.bias, 0)
 59             elif isinstance(m, nn.Linear):
 60                 nn.init.normal_(m.weight, 0, 0.01)
 61                 nn.init.constant_(m.bias, 0)
 62 
 63 
 64 def make_layers(cfg, batch_norm=False):
 65     layers = []
 66     in_channels = 3
 67     for v in cfg:
 68         if v == 'M':
 69             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
 70         else:
 71             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
 72             if batch_norm:
 73                 layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
 74             else:
 75                 layers += [conv2d, nn.ReLU(inplace=True)]
 76             in_channels = v
 77     return nn.Sequential(*layers)
 78 
 79 
 80 cfgs = {
 81     'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
 82     'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
 83     'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
 84     'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
 85 }
 86 
 87 
 88 def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
 89     if pretrained:
 90         kwargs['init_weights'] = False
 91     model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
 92     if pretrained:
 93         state_dict = load_state_dict_from_url(model_urls[arch],
 94                                               progress=progress)
 95         model.load_state_dict(state_dict)
 96     return model
 97 
 98 
 99 def vgg11(pretrained=False, progress=True, **kwargs):
100     """VGG 11-layer model (configuration "A")
101     Args:
102         pretrained (bool): If True, returns a model pre-trained on ImageNet
103         progress (bool): If True, displays a progress bar of the download to stderr
104     """
105     return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
106 
107 
def vgg11_bn(pretrained=False, progress=True, **kwargs):
    """Construct the 11-layer VGG model (configuration "A") with batch normalization.

    Args:
        pretrained (bool): When True, the returned model is pre-trained on ImageNet.
        progress (bool): When True, a download progress bar is shown on stderr.
        **kwargs: Extra keyword arguments passed through to ``_vgg``.
    """
    return _vgg(arch='vgg11_bn', cfg='A', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)
115 
116 
def vgg13(pretrained=False, progress=True, **kwargs):
    """Construct the 13-layer VGG model (configuration "B"), without batch norm.

    Args:
        pretrained (bool): When True, the returned model is pre-trained on ImageNet.
        progress (bool): When True, a download progress bar is shown on stderr.
        **kwargs: Extra keyword arguments passed through to ``_vgg``.
    """
    return _vgg(arch='vgg13', cfg='B', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
124 
125 
def vgg13_bn(pretrained=False, progress=True, **kwargs):
    """Construct the 13-layer VGG model (configuration "B") with batch normalization.

    Args:
        pretrained (bool): When True, the returned model is pre-trained on ImageNet.
        progress (bool): When True, a download progress bar is shown on stderr.
        **kwargs: Extra keyword arguments passed through to ``_vgg``.
    """
    return _vgg(arch='vgg13_bn', cfg='B', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)
133 
134 
def vgg16(pretrained=False, progress=True, **kwargs):
    """Construct the 16-layer VGG model (configuration "D"), without batch norm.

    Args:
        pretrained (bool): When True, the returned model is pre-trained on ImageNet.
        progress (bool): When True, a download progress bar is shown on stderr.
        **kwargs: Extra keyword arguments passed through to ``_vgg``.
    """
    return _vgg(arch='vgg16', cfg='D', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
142 
143 
def vgg16_bn(pretrained=False, progress=True, **kwargs):
    """Construct the 16-layer VGG model (configuration "D") with batch normalization.

    Args:
        pretrained (bool): When True, the returned model is pre-trained on ImageNet.
        progress (bool): When True, a download progress bar is shown on stderr.
        **kwargs: Extra keyword arguments passed through to ``_vgg``.
    """
    return _vgg(arch='vgg16_bn', cfg='D', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)
151 
152 
def vgg19(pretrained=False, progress=True, **kwargs):
    """Construct the 19-layer VGG model (configuration "E"), without batch norm.

    Args:
        pretrained (bool): When True, the returned model is pre-trained on ImageNet.
        progress (bool): When True, a download progress bar is shown on stderr.
        **kwargs: Extra keyword arguments passed through to ``_vgg``.
    """
    return _vgg(arch='vgg19', cfg='E', batch_norm=False,
                pretrained=pretrained, progress=progress, **kwargs)
160 
161 
def vgg19_bn(pretrained=False, progress=True, **kwargs):
    """Construct the 19-layer VGG model (configuration "E") with batch normalization.

    Args:
        pretrained (bool): When True, the returned model is pre-trained on ImageNet.
        progress (bool): When True, a download progress bar is shown on stderr.
        **kwargs: Extra keyword arguments passed through to ``_vgg``.
    """
    return _vgg(arch='vgg19_bn', cfg='E', batch_norm=True,
                pretrained=pretrained, progress=progress, **kwargs)

 參考

https://github.com/pytorch/vision/tree/master/torchvision/models

相關文章