
Understanding Inverted Residuals

Reference: 倒残差与线性瓶颈浅析 - MobileNetV2 (晓野豬's CSDN blog post on inverted residuals and linear bottlenecks in MobileNetV2)

To understand inverted residuals, it helps to first review the standard residual structure.

Residual structure

1. A 1×1 convolution reduces the channel dimension, e.g., from 256 input channels down to 64.

2. A 3×3 convolution is applied at the reduced width.

3. A 1×1 convolution restores the channel dimension, e.g., from 64 back up to 256 (see the sketch after this list).
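
For comparison, here is a minimal sketch of the bottleneck pattern just described. It assumes PyTorch; the class name BottleneckSketch and the fixed channel numbers (256 → 64 → 256, taken from the example above) are illustrative only, not the official torchvision implementation.

import torch
import torch.nn as nn

class BottleneckSketch(nn.Module):
    """Classic ResNet-style bottleneck: reduce -> 3x3 -> expand."""
    def __init__(self, channels=256, reduced=64):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, reduced, kernel_size=1, bias=False),   # 1x1 reduce: 256 -> 64
            nn.BatchNorm2d(reduced),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, reduced, kernel_size=3, padding=1, bias=False),  # 3x3 at width 64
            nn.BatchNorm2d(reduced),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, channels, kernel_size=1, bias=False),   # 1x1 expand: 64 -> 256
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Identity shortcut: input and output channel counts match,
        # so x can be added directly to the block output
        return self.relu(self.block(x) + x)

Note the wide-narrow-wide shape: the 3×3 convolution runs on the reduced 64-channel representation, which is what keeps the bottleneck cheap.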

With the residual structure in mind, we can now look at the inverted residual structure:

1. A 1×1 convolution expands the channel dimension, e.g., from 64 input channels up to 256.

2. A 3×3 depthwise convolution (the depthwise half of a depthwise-separable convolution) is applied at the expanded width.

3. A 1×1 convolution projects back down, e.g., from 256 to 64.

The activation function used here is ReLU6, i.e., a ReLU whose output is clipped at 6.
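
ReLU6 is simply a clamp to the range [0, 6]; the fixed output range is what makes it friendly to low-precision inference. A quick sketch using nothing beyond the standard PyTorch API:

import torch
import torch.nn as nn

x = torch.tensor([-3.0, 2.0, 9.0])
print(nn.ReLU6()(x))                       # tensor([0., 2., 6.])
print(torch.clamp(x, min=0.0, max=6.0))    # equivalent, written by hand

With that in place, here is the full inverted residual block in PyTorch: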

# _*_coding:utf-8_*_
import torch
import torch.nn as nn


class InvertedResidualsBlock(nn.Module):
    def __init__(self, in_channels, out_channels, expansion, stride):
        super(InvertedResidualsBlock, self).__init__()
        channels = expansion * in_channels
        self.stride = stride
        self.basic_block = nn.Sequential(
            # 1x1 convolution expands in_channels to expansion * in_channels
            nn.Conv2d(in_channels, channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU6(inplace=True),
            # 3x3 depthwise convolution (groups=channels); stride controls downsampling
            nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1,
                      groups=channels, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU6(inplace=True),
            # 1x1 convolution projects back down to out_channels (linear: no activation after it)
            nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        # 1x1 projection on the shortcut path so its channel count matches out_channels
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = self.basic_block(x)
        # The shortcut is added only when stride == 1, i.e. when the spatial
        # size is unchanged and the two branches can be summed element-wise
        if self.stride == 1:
            print("With shortcut!")
            out = out + self.shortcut(x)
        else:
            print("No shortcut!")
        print(out.size())
        return out


if __name__ == "__main__":
    x = torch.randn(16, 3, 32, 32)
    # stride 2: spatial size halves, no shortcut
    net1 = InvertedResidualsBlock(3, 6, 6, 2)
    # stride 1: spatial size preserved, shortcut added
    net2 = InvertedResidualsBlock(3, 6, 6, 1)
    y1, y2 = net1(x), net2(x)
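
Running the script: net1 (stride 2) prints "No shortcut!" and an output size of torch.Size([16, 6, 16, 16]), while net2 (stride 1) prints "With shortcut!" and torch.Size([16, 6, 32, 32]). One caveat: in the original MobileNetV2 paper, the shortcut is a plain identity, used only when stride == 1 and the input and output channel counts already match; the 1×1 projection shortcut above is a variant that also handles the case where the channel counts differ.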
