赞
踩
整理下特征提取网络resnet的网络结构
论文地址:https://arxiv.org/abs/1512.03385
文中所提供的代码来自:https://github.com/open-mmlab/mmclassification
有5个输出层C1,C2,C3,C4,C5,其中常用的是C2,C3,C4,C5层。没有单独的层进行下采样,直接在残差的时候进行下采样。
整个resnet50的forward代码如下(示例):
def forward(self, x): """Forward function.""" if self.deep_stem: # teem层 x = self.stem(x) else: x = self.conv1(x) x = self.norm1(x) x = self.relu(x) x = self.maxpool(x) outs = [] for i, layer_name in enumerate(self.res_layers): res_layer = getattr(self, layer_name) #获取相应名字的layer层:layer0,layer1... x = res_layer(x) # 进行操作 if i in self.out_indices: # 输出索引,指定输出的层数,用于后续的FPN操作。 outs.append(x) return tuple(outs)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
对应结构图:
def forward(self, x): """Forward function.""" def _inner_forward(x): identity = x out = self.conv1(x) out = self.norm1(out) out = self.relu(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv1_plugin_names) out = self.conv2(out) out = self.norm2(out) out = self.relu(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv2_plugin_names) out = self.conv3(out) out = self.norm3(out) if self.with_plugins: out = self.forward_plugin(out, self.after_conv3_plugin_names) if self.downsample is not None: identity = self.downsample(x) out += identity return out if self.with_cp and x.requires_grad: out = cp.checkpoint(_inner_forward, x) else: out = _inner_forward(x) out = self.relu(out) return out
Resnet的残差模块,使得神经网络能够有效的减轻梯度因为网络层数的逐渐加深而导致的梯度消失的问题。是一个十分经典的特征提取网络模块,后面还有基于resnet的res2net,resnest和resnext的改进。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。