
Classic Network Models: The Three Generations of MobileNet, Part 3 (V3)


MobileNet V3

1) Introduces the Squeeze-and-Excitation (SE) structure

2) Changes the nonlinearity: h-swish replaces swish (a minimal sketch comparing the two follows below)
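As a rough sketch of the difference (not from the original post): swish multiplies by a true sigmoid, while h-swish substitutes the piecewise-linear ReLU6, avoiding the exponential that is expensive on mobile CPUs. PyTorch also ships swish built in as F.silu.

import torch
import torch.nn.functional as F

def swish(x):
    # swish(x) = x * sigmoid(x)
    return x * torch.sigmoid(x)

def h_swish(x):
    # h-swish(x) = x * ReLU6(x + 3) / 6
    # ReLU6 is piecewise linear, so no exp() is needed
    return x * F.relu6(x + 3) / 6

x = torch.linspace(-6, 6, 13)
print(torch.max(torch.abs(swish(x) - h_swish(x))))  # the two stay close over the whole range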

SE-Net

The overall SE-Net structure can be plugged into any network model.

S: the Squeeze operation

Global average pooling is applied to the feature map, producing a 1×1×C output.

Each channel of the feature map describes one part of the features; after pooling, each value summarizes its channel globally.

E: the Excitation operation

To score the importance of each channel, two more fully connected layers follow (implemented as 1×1 convolutions below); the final output is again 1×1×C, acting like channel attention. With C = 40 and reduction = 4, for example, the two layers map 40 → 10 → 40 channels.

import torch
import torch.nn as nn
import torch.nn.functional as F


class hsigmoid(nn.Module):
    # hard sigmoid: ReLU6(x + 3) / 6, a piecewise-linear stand-in for sigmoid
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out


class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Squeeze: global average pooling -> 1x1xC
            # Excitation: two 1x1 convs act as the fully connected layers
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid()  # per-channel scores in [0, 1]
        )

    def forward(self, x):
        return x * self.se(x)  # rescale each channel by its score
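A quick shape check, assuming the definitions above are in scope (the tensor sizes are example values; C = 40 matches one of the V3-Large stages later in this post):

x = torch.randn(2, 40, 56, 56)  # batch of 2 feature maps with 40 channels
se = SeModule(40)
out = se(x)
print(out.shape)  # torch.Size([2, 40, 56, 56]): same shape, each channel rescaled by its score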

Comparison of MobileNet V2 and V3

V3 keeps V2's inverted residual block (expand → depthwise → pointwise) and adds two things inside it: an optional SE module (applied after the pointwise projection in this implementation) and the h-swish nonlinearity. The full MobileNetV3_Large implementation below shows both changes.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

# SeModule and hsigmoid from the snippet above are assumed to be in scope


class hswish(nn.Module):
    # h-swish: x * ReLU6(x + 3) / 6, the cheap swish approximation used by V3
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out


class Block(nn.Module):
    '''expand + depthwise + pointwise'''
    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule
        # 1x1 expansion conv
        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        # depthwise conv: groups == channels
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        # 1x1 projection conv (linear: no activation after bn3)
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            # 1x1 projection so the residual addition has matching shapes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.se is not None:
            out = self.se(out)
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV3_Large(nn.Module):  # nn.Module here; the source project used its own BaseModel
    def __init__(self, num_classes=1000):
        super(MobileNetV3_Large, self).__init__()
        # stem: 3x3 stride-2 conv on a 3-channel RGB input
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()
        # bneck stages: (kernel_size, in_size, expand_size, out_size, nonlinearity, SE, stride)
        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
            Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
            Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
            Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
            Block(3, 40, 240, 80, hswish(), None, 2),
            Block(3, 80, 200, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 184, 80, hswish(), None, 1),
            Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
            Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
            Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
            Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
            Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
        )
        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(960, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 7)  # assumes a 224x224 input, which is 7x7 by this point
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out
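A minimal smoke test, assuming a standard 224×224 RGB input (the hard-coded 7×7 average pool in forward requires it) and the SeModule/hsigmoid definitions from the earlier snippet in scope:

model = MobileNetV3_Large(num_classes=1000)
model.eval()  # eval mode lets the head's BatchNorm1d accept a batch of one
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])
print(sum(p.numel() for p in model.parameters()))  # rough parameter count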

Performance comparison

In the MobileNetV3 paper, V3-Large is reported to be about 3.2% more accurate on ImageNet classification than MobileNetV2 while reducing latency by about 20%.
