
PyTorch ResNet-18 source code walkthrough: resnet18(pretrained=True)


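ResNet-18 network architecture diagram

ResNet was proposed by Kaiming He et al. at Microsoft Research. Paper: Deep Residual Learning for Image Recognition (https://arxiv.org/pdf/1512.03385.pdf)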

ResNet code

PyTorch (torchvision.models.resnet) exports the following names:

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

ResNet declaration

Only ResNet-18 is covered here. It is called as follows:

from torchvision import models
resnet_18 = models.resnet18(pretrained=True)

ResNet

Here pretrained indicates whether to load weights pretrained on ImageNet. The ResNet-18 model is defined as follows:

def resnet18(pretrained=False, progress=True, **kwargs):
   r"""ResNet-18 model from
   `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

   Args:
       pretrained (bool): If True, returns a model pre-trained on ImageNet
       progress (bool): If True, displays a progress bar of the download to stderr
   """
   return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                  **kwargs)
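
The list [2, 2, 2, 2] passed to _resnet above is the number of residual blocks per stage, and it is also where the "18" in the name comes from. Below is a minimal sketch of that arithmetic in plain Python (no torch needed). The helper name count_weighted_layers is made up for illustration, and the [3, 4, 6, 3] configuration for resnet34 comes from torchvision rather than from this article. By convention the 1x1 convolutions on the shortcut branch are not counted.

```python
def count_weighted_layers(layers, convs_per_block):
    """Count the stem conv, the convs inside all residual blocks, and the final fc layer.

    `layers` is the per-stage block count (e.g. [2, 2, 2, 2]);
    `convs_per_block` is 2 for BasicBlock (two 3x3 convolutions).
    This helper is hypothetical and only mirrors the naming convention.
    """
    stem_conv = 1                                # the initial 7x7 convolution
    block_convs = convs_per_block * sum(layers)  # convolutions inside the residual blocks
    fc = 1                                       # the final fully connected layer
    return stem_conv + block_convs + fc


resnet18_depth = count_weighted_layers([2, 2, 2, 2], convs_per_block=2)
resnet34_depth = count_weighted_layers([3, 4, 6, 3], convs_per_block=2)
print(resnet18_depth, resnet34_depth)
```

So resnet18 and resnet34 differ only in the layers list; the block type (BasicBlock) is the same.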

resnet18 calls the module-private helper function _resnet, which is defined as follows:

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """
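    arch: name of the network, used as the key into model_urls
    block: residual block type; either BasicBlock or Bottleneck
    layers: number of residual blocks in each stage (a list of length 4)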
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

The ResNet class is defined as follows:

class ResNet(nn.Module):

   def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                groups=1, width_per_group=64, replace_stride_with_dilation=None,
                norm_layer=None):
       """
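       block: residual block type; either BasicBlock or Bottleneck
       layers: number of residual blocks in each stage (a list of length 4)
       num_classes: number of output classes
       zero_init_residual: if True, initialize the weight of the last BN layer
           in each residual block to 0, so that the residual branch starts at
           zero and each residual block initially behaves like an identity mapping.
           According to https://arxiv.org/abs/1706.02677 this improves the model by 0.2%~0.3%.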
           """
       super(ResNet, self).__init__()
       if norm_layer is None:
           norm_layer = nn.BatchNorm2d
       self._norm_layer = norm_layer

       self.inplanes = 64
       self.dilation = 1
       if replace_stride_with_dilation is None:
           # each element in the tuple indicates if we should replace
           # the 2x2 stride with a dilated convolution instead
           replace_stride_with_dilation = [False, False, False]
       if len(replace_stride_with_dilation) != 3:
           raise ValueError("replace_stride_with_dilation should be None "
                            "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
       self.groups = groups
       self.base_width = width_per_group
       self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                              bias=False)
       self.bn1 = norm_layer(self.inplanes)
       self.relu = nn.ReLU(inplace=True)
       self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
       self.layer1 = self._make_layer(block, 64, layers[0])
       self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                      dilate=replace_stride_with_dilation[0])
       self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                      dilate=replace_stride_with_dilation[1])
       self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                      dilate=replace_stride_with_dilation[2])
       self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
       self.fc = nn.Linear(512 * block.expansion, num_classes)

       for m in self.modules():
           if isinstance(m, nn.Conv2d):
               nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
           elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
               nn.init.constant_(m.weight, 1)
               nn.init.constant_(m.bias, 0)

       # Zero-initialize the last BN in each residual branch,
       # so that the residual branch starts with zeros, and each residual block behaves like an identity.
       # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
       if zero_init_residual:
           for m in self.modules():
               if isinstance(m, Bottleneck):
                   nn.init.constant_(m.bn3.weight, 0)
               elif isinstance(m, BasicBlock):
                   nn.init.constant_(m.bn2.weight, 0)

   def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
       norm_layer = self._norm_layer
       downsample = None
       previous_dilation = self.dilation
       if dilate:
           self.dilation *= stride
           stride = 1
       if stride != 1 or self.inplanes != planes * block.expansion:
           downsample = nn.Sequential(
               conv1x1(self.inplanes, planes * block.expansion, stride),
               norm_layer(planes * block.expansion),
           )

       layers = []
       layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                           self.base_width, previous_dilation, norm_layer))
       self.inplanes = planes * block.expansion
       for _ in range(1, blocks):
           layers.append(block(self.inplanes, planes, groups=self.groups,
                               base_width=self.base_width, dilation=self.dilation,
                               norm_layer=norm_layer))

       return nn.Sequential(*layers)

   def _forward_impl(self, x):
       # See note [TorchScript super()]
       x = self.conv1(x)
       x = self.bn1(x)
       x = self.relu(x)
       x = self.maxpool(x)

       x = self.layer1(x)
       x = self.layer2(x)
       x = self.layer3(x)
       x = self.layer4(x)

       x = self.avgpool(x)
       x = torch.flatten(x, 1)
       x = self.fc(x)

       return x

   def forward(self, x):
       return self._forward_impl(x)
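
The zero_init_residual option in the code above is easier to see with a toy calculation. The following is a minimal sketch in plain Python, not the torchvision implementation: residual_block_output is a hypothetical helper, the numbers are made up, and bn_weight stands in for the scale (gamma) of the last BN layer, whose bias is already initialized to 0 by the constructor.

```python
def relu(values):
    """Element-wise ReLU on a plain Python list."""
    return [max(0.0, v) for v in values]


def residual_block_output(identity, branch, bn_weight):
    """Toy model of the end of a residual block: relu(bn_weight * branch + identity).

    `branch` stands for the normalized activations of the residual branch
    just before the last BN scale is applied (BN bias is assumed to be 0).
    """
    scaled = [bn_weight * b for b in branch]
    return relu([s + i for s, i in zip(scaled, identity)])


identity = [0.0, 0.5, 2.0]  # non-negative, as produced by a preceding ReLU
branch = [1.3, -0.7, 4.2]   # made-up residual-branch activations

zero_init_out = residual_block_output(identity, branch, bn_weight=0.0)
default_out = residual_block_output(identity, branch, bn_weight=1.0)
print(zero_init_out)  # identical to `identity`
```

With the scale at 0 the block passes its (non-negative) input through unchanged at initialization. In blocks that have a downsample module the shortcut is itself a learned projection, so the block is then not an exact identity.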


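ResNet begins with a 7x7 convolution with stride 2, which downsamples the input by 2x, followed by a max-pooling layer that downsamples by another 2x. Next come four stages (layer1 to layer4), each built by _make_layer, and finally global average pooling and a fully connected layer. The implementation of _make_layer is described next.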
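
To make the downsampling concrete, here is a small sketch in plain Python that traces the spatial size of the feature map through the network, using the kernel sizes, strides and paddings from the ResNet constructor. It assumes a square input of 224x224 (the usual ImageNet crop, not stated in the code itself) and the default dilation of 1. The helper names are made up for illustration.

```python
def conv_out_size(size, kernel, stride, padding):
    """Output size of a convolution or max-pool with dilation 1 (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_feature_sizes(input_size=224):
    """Trace the spatial size through the stem and the four stages."""
    sizes = {}
    size = conv_out_size(input_size, kernel=7, stride=2, padding=3)  # conv1
    sizes["conv1"] = size
    size = conv_out_size(size, kernel=3, stride=2, padding=1)        # maxpool
    sizes["maxpool"] = size
    for name, stride in [("layer1", 1), ("layer2", 2), ("layer3", 2), ("layer4", 2)]:
        # Only the first 3x3 convolution of a stage uses the stage stride;
        # every other convolution in the stage keeps the size unchanged.
        size = conv_out_size(size, kernel=3, stride=stride, padding=1)
        sizes[name] = size
    sizes["avgpool"] = 1  # AdaptiveAvgPool2d((1, 1)) always outputs 1x1
    return sizes


sizes = resnet_feature_sizes(224)
print(sizes)
```

For a 224x224 input the sizes go 112, 56, 56, 28, 14, 7 and finally 1 after the adaptive average pooling, i.e. a total downsampling of 32x before pooling.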

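_make_layer definition

From the code: block is the block type, either BasicBlock or Bottleneck depending on the depth of the ResNet; planes is the base number of output channels of the convolutions in each block (the stage outputs planes * block.expansion channels); blocks is an int giving the number of blocks in this stage.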
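
The bookkeeping inside _make_layer can be replayed without torch. The sketch below is a simplified, hypothetical re-implementation (plan_stage and plan_resnet18 are names invented here) that ignores the dilate option and only records, for each block, the input channels, output channels, stride, and whether a downsample module would be created.

```python
def plan_stage(inplanes, planes, blocks, stride=1, expansion=1):
    """Mimic the channel/stride bookkeeping of _make_layer (dilation ignored).

    Returns (plan, new_inplanes) where plan is a list of
    (in_channels, out_channels, stride, has_downsample) tuples.
    """
    out_channels = planes * expansion
    # Same condition as in _make_layer: the shortcut needs a projection
    # when the spatial size or the channel count changes.
    has_downsample = stride != 1 or inplanes != out_channels
    plan = [(inplanes, out_channels, stride, has_downsample)]
    for _ in range(1, blocks):
        # The remaining blocks keep both size and channel count.
        plan.append((out_channels, out_channels, 1, False))
    return plan, out_channels


def plan_resnet18():
    """Apply plan_stage to the four stages of ResNet-18 (layers = [2, 2, 2, 2])."""
    inplanes = 64
    stages = {}
    for name, planes, blocks, stride in [
        ("layer1", 64, 2, 1),
        ("layer2", 128, 2, 2),
        ("layer3", 256, 2, 2),
        ("layer4", 512, 2, 2),
    ]:
        stages[name], inplanes = plan_stage(inplanes, planes, blocks, stride)
    return stages


stages = plan_resnet18()
for name, plan in stages.items():
    print(name, plan)
```

In this sketch layer1 needs no downsample module (64 channels in and out, stride 1), while the first block of layer2, layer3 and layer4 each gets one.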

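Residual block definition

As noted in the comments above, ResNet has two block types: BasicBlock and Bottleneck.

  1. BasicBlock is the residual block used by resnet18 and resnet34
  2. Bottleneck is the residual block used by resnet50, resnet101 and resnet152

BasicBlock is introduced first: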

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        """
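        inplanes: number of input channels (int)
        planes: number of output channels of the two 3x3 convolutions (int)
        stride: stride of the first convolution
        downsample: optional module applied to the shortcut branch (nn.Sequential)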
        """
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

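Each block contains, in order: conv3x3, bn, relu, conv3x3, bn. In forward, the shortcut connection is implemented with out += identity, followed by a final relu. If stride=2, the first conv3x3 downsamples the feature map, so the shortcut branch has to be downsampled too. This is done by the downsample argument, which _make_layer builds as a 1x1 convolution (conv1x1) followed by a normalization layer.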
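
As a sanity check on the block structure, the number of learnable parameters can be counted by hand. This is a sketch in plain Python with hypothetical helper names. It assumes what the code above shows: convolutions have no bias, each BatchNorm2d layer has a weight and a bias per channel, and the shortcut projection is a 1x1 convolution plus a normalization layer.

```python
def conv_params(cin, cout, kernel):
    """Parameters of a bias-free 2D convolution with groups=1."""
    return cin * cout * kernel * kernel


def bn_params(channels):
    """Learnable parameters of BatchNorm2d: weight and bias per channel."""
    return 2 * channels


def basic_block_params(inplanes, planes, stride):
    """conv3x3 + bn + conv3x3 + bn, plus conv1x1 + bn on the shortcut if needed."""
    total = conv_params(inplanes, planes, 3) + bn_params(planes)
    total += conv_params(planes, planes, 3) + bn_params(planes)
    if stride != 1 or inplanes != planes:
        total += conv_params(inplanes, planes, 1) + bn_params(planes)
    return total


def resnet18_params(num_classes=1000):
    """Stem + four stages of two BasicBlocks each + fully connected layer."""
    total = conv_params(3, 64, 7) + bn_params(64)
    inplanes = 64
    for planes, stride in [(64, 1), (128, 2), (256, 2), (512, 2)]:
        total += basic_block_params(inplanes, planes, stride)  # first block of the stage
        total += basic_block_params(planes, planes, 1)         # second block
        inplanes = planes
    total += 512 * num_classes + num_classes                   # fc weight and bias
    return total


total_params = resnet18_params()
print(total_params)  # 11689512
```

The sketch gives 11,689,512 parameters for the 1000-class model, of which the shortcut projections contribute only a small share because they use 1x1 kernels.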

ResNet architecture diagram

Source: https://www.jianshu.com/p/085f4c8256f1

Closing

This is my first blog post; please point out anything I have missed or got wrong.
