A Brief Look at Inverted Residuals and Linear Bottlenecks in MobileNetV2 (晓野豬's blog, CSDN)
To understand the inverted residual, first review the ordinary residual (bottleneck) structure.

Residual structure:
1. A 1×1 convolution reduces the channel dimension, e.g. from 256 down to 64.
2. A standard 3×3 convolution operates at the reduced dimension.
3. A 1×1 convolution expands the channels back, e.g. from 64 to 256.
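As a point of comparison (not from the original post), a minimal sketch of the classic residual bottleneck described above, with assumed channel sizes 256 → 64 → 256 and an identity shortcut:

```python
import torch
import torch.nn as nn


class ResidualBottleneck(nn.Module):
    """Classic ResNet-style bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand."""

    def __init__(self, channels=256, bottleneck=64):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, bottleneck, kernel_size=1, bias=False),  # 256 -> 64
            nn.BatchNorm2d(bottleneck),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, bottleneck, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(bottleneck),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, channels, kernel_size=1, bias=False),  # 64 -> 256
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Identity shortcut: input and output channel counts match
        return self.relu(self.block(x) + x)


x = torch.randn(2, 256, 8, 8)
y = ResidualBottleneck()(x)
print(y.shape)  # torch.Size([2, 256, 8, 8])
```

The shape is preserved, so the identity shortcut can be added directly without any projection.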
With the residual structure understood, the inverted residual structure reverses the order:
1. A 1×1 convolution expands the channel dimension, e.g. from 64 up to 256.
2. A depthwise separable 3×3 convolution operates at the expanded dimension.
3. A 1×1 convolution reduces the channels back down, e.g. from 256 to 64.
The activation function used here is ReLU6. Note that the final 1×1 projection is left linear (no ReLU6 after it), which is the "linear bottleneck" in the title.
```python
# _*_coding:utf-8_*_
import torch
import torch.nn as nn


class InvertedResidualsBlock(nn.Module):
    def __init__(self, in_channels, out_channels, expansion, stride):
        super(InvertedResidualsBlock, self).__init__()
        channels = expansion * in_channels
        self.stride = stride

        # 1x1 expand -> 3x3 depthwise -> 1x1 project (linear bottleneck)
        self.basic_block = nn.Sequential(
            # 1x1 pointwise convolution expands the channels
            nn.Conv2d(in_channels, channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU6(inplace=True),
            # 3x3 depthwise convolution (groups == channels)
            nn.Conv2d(channels, channels, kernel_size=3, stride=stride,
                      padding=1, groups=channels, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU6(inplace=True),
            # 1x1 pointwise projection; no ReLU6 here (linear bottleneck)
            nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        # Projection shortcut so the channel counts match.
        # (The MobileNetV2 paper uses an identity shortcut, applied only
        # when stride == 1 and in_channels == out_channels.)
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = self.basic_block(x)
        if self.stride == 1:
            print("With shortcut!")
            out = out + self.shortcut(x)
        else:
            print("No shortcut!")
        print(out.size())
        return out


if __name__ == "__main__":
    x = torch.randn(16, 3, 32, 32)
    # stride 2: no shortcut
    net1 = InvertedResidualsBlock(3, 6, 6, 2)
    # stride 1: with shortcut
    net2 = InvertedResidualsBlock(3, 6, 6, 1)
    y1, y2 = net1(x), net2(x)
```
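The reason the 3×3 convolution in step 2 can afford to run at the expanded width is that it is depthwise separable. A quick parameter count (not from the original post, channel size 256 assumed for illustration) comparing a standard 3×3 convolution with its depthwise + pointwise factorization:

```python
import torch.nn as nn

channels = 256

# Standard 3x3 convolution: channels * channels * 3 * 3 weights
standard = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
# Depthwise 3x3 (one filter per channel) followed by a 1x1 pointwise conv
depthwise = nn.Conv2d(channels, channels, kernel_size=3, padding=1,
                      groups=channels, bias=False)
pointwise = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard))                      # 589824 (= 256 * 256 * 3 * 3)
print(count(depthwise) + count(pointwise))  # 67840 (= 256 * 3 * 3 + 256 * 256)
```

The separable pair uses roughly 8–9× fewer parameters here, which is what makes the wide intermediate layer of the inverted residual cheap.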