
The Widely Used Residual Block

residual block

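The core idea of ResNet is the "identity shortcut connection": a connection that skips one or more layers and adds the block's input directly to their output, as shown in the figure below: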

 

Figure: A deeper residual function F for ImageNet. Left: a building block (BasicBlock) as used in ResNet-34. Right: the "bottleneck" building block used in ResNet-50/101/152.
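Formally, if F denotes the layers being skipped, the block outputs F(x) + x instead of F(x). The minimal PyTorch sketch below wraps a shape-preserving layer in a hypothetical Residual module, which is not a torch or torchvision class. Setting F's parameters to zero makes the block exactly the identity. This illustrates the intuition that a residual block can easily fall back to an identity mapping.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Hypothetical wrapper: returns F(x) + x for any shape-preserving module F."""

    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The identity shortcut adds the input back; it has no parameters of its own.
        return self.fn(x) + x


block = Residual(nn.Linear(8, 8))
# With F's weight and bias set to zero, F(x) = 0, so the block returns x unchanged.
nn.init.zeros_(block.fn.weight)
nn.init.zeros_(block.fn.bias)
x = torch.randn(2, 8)
print(torch.equal(block(x), x))  # True
```

The shortcut only works when F(x) and x have the same shape. The residual blocks below handle the case where they do not with a downsample module on the shortcut.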

BasicBlock

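expansion is the ratio between a block's output channels and its planes argument (the width of its internal convolutions). BasicBlock does not widen its output, so expansion = 1.
In a residual block, the final ReLU is applied after the shortcut addition, not before it.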
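To make expansion concrete, the small pure-Python sketch below computes a BasicBlock's output shape from the standard convolution size formula. It assumes 3x3 convolutions with padding 1 and the stride applied only in conv1, as in BasicBlock. The helper names conv_out_size and basic_block_out_shape are hypothetical.

```python
def conv_out_size(size, kernel=3, stride=1, padding=1):
    """Convolution output size: floor((size + 2 * padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def basic_block_out_shape(height, width, planes, stride=1, expansion=1):
    """(channels, height, width) produced by a BasicBlock.

    conv1 is 3x3 with the given stride, conv2 is 3x3 with stride 1; both use padding 1.
    The channel count is planes * expansion, and expansion is 1 for BasicBlock.
    """
    h = conv_out_size(conv_out_size(height, stride=stride))
    w = conv_out_size(conv_out_size(width, stride=stride))
    return planes * expansion, h, w


print(basic_block_out_shape(56, 56, planes=64))             # (64, 56, 56)
print(basic_block_out_shape(56, 56, planes=128, stride=2))  # (128, 28, 28)
```

A 3x3 convolution with stride 1 and padding 1 preserves the spatial size, so only conv1's stride changes H and W. Whenever the output shape differs from the input shape, the shortcut needs the downsample projection.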

 

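```python
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding=1 (simplified from torchvision's resnet.py); no bias since BN follows."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # conv1 may halve the spatial size when stride=2
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # optional projection so the shortcut matches the main path's shape
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity  # shortcut addition
        out = self.relu(out)  # ReLU only after the sum
        return out
```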
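The downsample argument is left to the caller. In torchvision, ResNet._make_layer builds it as a 1x1 convolution followed by BatchNorm whenever stride != 1 or inplanes != planes * block.expansion. The sketch below reproduces that rule as a standalone function. make_downsample is a hypothetical name, not a torchvision API.

```python
import torch
import torch.nn as nn


def make_downsample(inplanes, planes, stride=1, expansion=1):
    """Hypothetical helper mirroring the shortcut rule in torchvision's ResNet._make_layer.

    A projection (1x1 conv + BN) is needed only when the main path changes the spatial
    size (stride != 1) or the channel count (inplanes != planes * expansion).
    """
    out_channels = planes * expansion
    if stride == 1 and inplanes == out_channels:
        return None  # shapes already match: plain identity shortcut
    return nn.Sequential(
        nn.Conv2d(inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(out_channels),
    )


x = torch.randn(1, 64, 56, 56)
print(make_downsample(64, 64))  # None
proj = make_downsample(64, 128, stride=2)
print(proj(x).shape)  # torch.Size([1, 128, 28, 28])
```

Passing the returned module as downsample, for example BasicBlock(64, 128, stride=2, downsample=proj), gives identity and out the same shape before out += identity.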

bottleneck

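Note that ResNet-18 and ResNet-34 use BasicBlock, while ResNet-50, ResNet-101 and ResNet-152 use Bottleneck. Bottleneck reduces the number of channels before the expensive 3x3 convolution, which cuts parameters and computation so that much deeper networks stay affordable. Its structure can be summarized as "reduce, convolve, expand":

1. A 1x1 convolution shrinks the channels to planes.
2. A 3x3 convolution operates at that narrow width.
3. A second 1x1 convolution expands the channels back to planes * 4.

Hence expansion = 4: the block's output has four times as many channels as its internal width (planes).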
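A rough way to see why the bottleneck is cheaper is to count convolution weights, ignoring BatchNorm parameters. The pure-Python sketch below compares three layouts:

- a bottleneck block on a 256-channel input (planes = 64);
- two full-width 3x3 convolutions at 256 channels;
- a 64-channel BasicBlock.

conv_weights is a hypothetical helper. At stride 1 all three layouts run at the same spatial resolution, so multiply-add counts scale the same way.

```python
def conv_weights(c_in, c_out, k):
    """Number of weights in a bias-free k x k convolution."""
    return c_in * c_out * k * k


# Bottleneck on a 256-channel input with planes = 64 (expansion 4): reduce, convolve, expand.
bottleneck = (conv_weights(256, 64, 1)     # 1x1: 256 -> 64
              + conv_weights(64, 64, 3)    # 3x3: 64 -> 64
              + conv_weights(64, 256, 1))  # 1x1: 64 -> 256

# For comparison: two 3x3 convolutions kept at the full 256 channels.
basic_256 = 2 * conv_weights(256, 256, 3)

# A BasicBlock at 64 channels, as in the first stage of ResNet-34.
basic_64 = 2 * conv_weights(64, 64, 3)

print(bottleneck, basic_256, basic_64)  # 69632 1179648 73728
```

The bottleneck uses about 17 times fewer weights than keeping both 3x3 convolutions at 256 channels. It also uses slightly fewer than a 64-channel BasicBlock, even though its input and output are 256 channels wide.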

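```python
import torch.nn as nn


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (simplified from torchvision's resnet.py)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding=1 (simplified from torchvision's resnet.py)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class Bottleneck(nn.Module):
    expansion = 4  # output channels = planes * 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)  # reduce: inplanes -> planes
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)  # 3x3 at the narrow width
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)  # expand: planes -> planes * 4
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)  # no ReLU here: it comes after the addition

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
```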
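The pure-Python sketch below shows how expansion = 4 and the downsample rule play out across a whole network. It reproduces the channel bookkeeping that torchvision's ResNet._make_layer performs. The stage configurations are the published ones for ResNet-18/34/50/101/152. The names CONFIGS and channel_plan are hypothetical.

The depth check follows the paper's counting convention. It counts the stem convolution, the convolutions inside the blocks, and the final fully connected layer. The 1x1 projection shortcuts are not counted.

```python
# Blocks per stage for each depth, and which block type each depth uses.
CONFIGS = {
    18: ("basic", [2, 2, 2, 2]),
    34: ("basic", [3, 4, 6, 3]),
    50: ("bottleneck", [3, 4, 6, 3]),
    101: ("bottleneck", [3, 4, 23, 3]),
    152: ("bottleneck", [3, 8, 36, 3]),
}
EXPANSION = {"basic": 1, "bottleneck": 4}
CONVS_PER_BLOCK = {"basic": 2, "bottleneck": 3}


def channel_plan(depth):
    """List (inplanes, planes, out_channels, stride, needs_downsample) for every block."""
    kind, layers = CONFIGS[depth]
    expansion = EXPANSION[kind]
    inplanes = 64  # channels after the stem (7x7 conv + max pool)
    plan = []
    for stage, (planes, n_blocks) in enumerate(zip([64, 128, 256, 512], layers)):
        for i in range(n_blocks):
            # The first block of stages 2-4 halves the spatial size.
            stride = 2 if (stage > 0 and i == 0) else 1
            out_channels = planes * expansion
            needs_downsample = stride != 1 or inplanes != out_channels
            plan.append((inplanes, planes, out_channels, stride, needs_downsample))
            inplanes = out_channels  # the next block sees this block's output
    return plan


# Depth = stem conv + final fc + all convolutions inside the residual blocks.
for depth, (kind, layers) in CONFIGS.items():
    assert 2 + CONVS_PER_BLOCK[kind] * sum(layers) == depth

plan50 = channel_plan(50)
print(plan50[0])      # (64, 64, 256, 1, True)
print(plan50[1])      # (256, 64, 256, 1, False)
print(plan50[-1][2])  # 2048
```

In ResNet-50 the first block already needs a projection, because it expands 64 channels to 256. Later blocks in the same stage take 256 channels in, reduce them to 64 internally, and expand back to 256, so their shortcut is a plain identity.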

 

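Disclaimer: this article was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author. Please credit the source when reposting: https://www.wpsshop.cn/w/weixin_40725706/article/detail/195907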