当前位置:   article > 正文

ssd网络结构_SSD 源码实现详细解析 (PyTorch)-模型定义篇

ssd神经网络源码

概览

SSD 和 YOLO 都是非常主流的 one-stage 目标检测模型, 并且相对于 two-stage 的 RCNN 系列来说, SSD 的实现更加的简明易懂, 接下来我将从以下几个方面展开对 SSD 模型的源码实现讲解: - 模型结构定义 - DefaultBox 生成候选框 - 解析预测结果 - MultiBox 损失函数 - Augmentations Trick - 模型训练 - 模型预测 - 模型验证 - 其他辅助代码

可以看出, 虽然 SSD 模型本身并不复杂, 但是也正是由于 one-stage 模型较简单的原因, 其检测的准确率相对于 two-stage 模型较低, 因此, 通常需要借助许多训练和检测时的 Tricks 来提升模型的精确度, 这些代码我们会放在第三部分讲解. 下面, 我们按照顺序首先对 SSD 模型结构定义的源码进行解析.(项目地址: https://github.com/amdegroot/ssd.pytorch)

模型结构定义

本部分代码主要位于 ssd.py 文件里面, 在本文件中, 定义了SSD的模型结构. 主要包含以下类和函数, 整体概览如下:

  1. # ssd.py
  2. class SSD(nn.Module): # 自定义SSD网络
  3. def __init__(self, phase, size, base, extras, head, num_classes):
  4. # ... SSD 模型初始化
  5. def forward(self, x):
  6. # ... 定义forward函数, 将设计好的layers和ops应用到输入图片 x 上
  7. def load_weights(self, base_file):
  8. # ... 加载参数权重值
  9. def vgg(cfg, i, batch_norm=False):
  10. # ... 搭建vgg网络
  11. def add_extras(cfg, i, batch_norm=False):
  12. # ... 向VGG网络中添加额外的层用于feature scaling
  13. def multibox(vgg, extra_layers, cfg, num_classes):
  14. # ... 构建multibox结构
  15. base = {
  16. ...} # vgg 网络结构参数
  17. extras = {
  18. ...} # extras 层参数
  19. mbox = {
  20. ...} # multibox 相关参数
  21. def build_ssd(phase, size=300, num_classes=21):
  22. # ... 构建模型函数, 调用上面的函数进行构建

为了方便理解, 我们不按照文件中的定义顺序解析, 而是根据文件中函数的调用关系来从外而内, 从上而下的进行解析, 解析顺序如下: - build_ssd(...) 函数 - vgg(...) 函数 - add_extras(...) 函数 - multibox(...) 函数 - SSD(nn.Module) 类

build_ssd(...) 函数

在其他文件通常利用build_ssd(phase, size=300, num_classes=21)函数来创建模型, 下面先看看该函数的具体实现:

  1. # ssd.py
  2. class SSD(nn.Module): # 自定义SSD网络
  3. def __init__(self, phase, size, base, extras, head, num_classes):
  4. # ...
  5. def forward(self, x):
  6. # ...
  7. def load_weights(self, base_file):
  8. # ...
  9. def vgg(cfg, i, batch_norm=False):
  10. # ... 搭建vgg网络
  11. def add_extras(cfg, i, batch_norm=False):
  12. # ... 向VGG网络中添加额外的层用于feature scaling
  13. def multibox(vgg, extra_layers, cfg, num_classes):
  14. # ... 构建multibox结构
  15. base = {
  16. # vgg 网络结构参数
  17. '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512],
  18. '500': []
  19. }
  20. extras = {
  21. # extras 层参数
  22. '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
  23. '500': []
  24. }
  25. mbox = {
  26. # multibox 相关参数
  27. '300': [4, 6, 6, 6, 4, 4],
  28. '500': []
  29. }
  30. def build_ssd(phase, size=300, num_classes=21):
  31. # 构建模型函数, 调用上面的函数进行构建
  32. if phase != "test" and phase != "train": # 只能是训练或者预测阶段
  33. print("ERROR: Phase: " + phase + " not recognized")
  34. return
  35. if size != 300:
  36. print("ERROR: You specified size " + repr(size) + ". However, "+
  37. "currently only SSD300 is supported!") # 仅仅支持300size的SSD
  38. return
  39. base_, extras_, head_ = multibox(vgg(base[str(size)], 3),
  40. add_extras(extras[str(size), 1024),
  41. mbox[str(size)], num_classes )
  42. return SSD(phase, size, base_, extras_, head_, num_classes)

可以看到, build_ssd(...)函数主要使用了multibox(...)函数来获取base_, extras_, head_, 在调用multibox(...)函数的同时, 还分别调用了vgg(...)函数, add_extras(...)函数, 并将其返回值作为参数. 之后, 利用这些信息初始化了SSD网络. 那么下面, 我们就先查看一下这些函数定义和作用

vgg(...) 函数

我们以调用顺序为依据, 先对multibox(...)函数的内部实现进行解析, 但是在查看multibox(...)函数之前, 我们首先需要看看其参数的由来, 首先是vgg(...)函数, 因为 SSD 是以 VGG 网络作为 backbone 的, 因此该函数主要定义了 VGG 网络的结果, 根据调用语句vgg(base[str(size)], 3)可以看出, 调用vgg时向其传入了两个参数, 分别为base[str(size)]3, 对应的就是base['300']和3.

  1. # ssd.py
  2. def vgg(cfg, i, batch_norm = False):
  3. # cfg = base['300'] = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512],
  4. # i = 3
  5. layers = []
  6. in_channels = i
  7. for v in cfg:
  8. if v == 'M':
  9. layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
  10. if v == 'C':
  11. layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
  12. else:
  13. conv2d = nn.Conv2d(in_channels=in_channels, out_channels=v, kernel_size=3, padding=1)
  14. if batch_norm:
  15. layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
  16. else:
  17. layers += [conv2d, nn.ReLU(inplace=True)]
  18. in_channels = v
  19. pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
  20. conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
  21. conv7 = nn.Con2d(1024, 1024, kernel_size=1)
  22. layers += [pool5, conv6, nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
  23. return layers

上面的写法是 ssd.pytorch 代码中的原始写法, 代码风格体现了 PyTorch 灵活的编程特性, 但是这种写法不是那么直观, 需要很详细的解读才能看出来这个网络的整个结构是什么样的. 建议大家结合 VGG 网络的整个结构来解读这部分代码, 核心思想就是通过预定义的 cfg=base={...} 里面的参数来设置 vgg 网络卷积层和池化层的参数设置, 由于 vgg 网络的模型结构很经典, 有很多文章都写的很详细, 这里就不再啰嗦了, 我们主要来看一下 SSD 网络中比较重要的点, 也就是下面的 extras_layers.

add_extras(...) 函数

想必了解 SSD 模型的朋友都知道, SSD 模型中是利用多个不同层级上的 feature map 来进行同时进行边框回归和物体分类任务的, 除了使用 vgg 最深层的卷积层以外, SSD 还添加了几个卷积层, 专门用于执行回归和分类任务(如文章开头图2所示), 因此, 我们在定义完 VGG 网络以后, 需要额外定义这些新添加的卷积层. 接下来, 我们根据论文中的参数设置, 来看一下 add_extras(...) 的内部实现, 根据调用语句add_extras(extras[str(size)], 1024) 可知, 该函数中参数cfg = extras['300'], i=1024.

  1. # ssd.py
  2. def add_extras(cfg, i, batch_norm=False):
  3. # cfg = [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256]
  4. # i = 1024
  5. layers = []
  6. in_channels = i
  7. flag = False
  8. for k, v in enumerate(cfg):
  9. if in_channels != 'S':
  10. if v == 'S': # (1,3)[True] = 3, (1,3)[False] = 1
  11. layers += [nn.Conv2d(in_channels=in_channels, out_channels=cfg[k+1],
  12. kernel_size=(1, 3)[flag], stride=2, padding=1)]
  13. else:
  14. layers += [nn.Conv2d(in_channels=in_channels, out_channels=v,
  15. kernel_size=(1, 3)[flag])]
  16. flag = not flag
  17. in_channels = v
  18. return layers

注意, 在extras中, 卷积层之间并没有使用 BatchNorm 和 ReLU, 实际上, ReLU 的使用放在了forward函数中

同样的问题, 上面的定义不是很直观, 因此我将上面的代码用 PyTorch 重写了, 重写后的代码更容易看出网络的结构信息, 同时可读性也较强, 代码如下所示(与上面的代码完全等价):

  1. def add_extras():
  2. exts1_1 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)
  3. exts1_2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)
  4. exts2_1 = nn.Conv2d(512, 128, 1, 1, 0)
  5. exts2_2 = nn.Conv2d(128, 256, 3, 2, 1)
  6. exts3_1 = nn.Conv2d(256, 128, 1, 1, 0)
  7. exts3_2 = nn.Conv2d(128, 256, 3, 1, 0)
  8. exts4_1 = nn.Conv2d(256, 128, 1, 1, 0)
  9. exts4_2 = nn.Conv2d(128, 256, 3, 1, 0)
  10. return [exts1_1, exts1_2, exts2_1, exts2_2, exts3_1, exts3_2, exts4_1, exts4_2]

在定义完整个的网络结构以后, 我们就需要定义最后的 head 层, 也就是特定的任务层, 因为 SSD 是 one-stage 模型, 因此它是同时在特征图谱上产生预测边框和预测分类的, 我们根据类别的数量来设置相应的网络预测层参数, 注意需要用到多个特征图谱, 也就是说要有多个预测层(原文中用了6个卷积特征图谱, 其中2个来自于 vgg 网络, 4个来自于 extras 层), 代码实现如下:

multibox(...) 函数

multibox(...) 总共有4个参数, 现在我们已经得到了两个参数, 分别是vgg(...)函数返回的layers, 以及add_extras(...)函数返回的layers, 后面两个参数根据调用语句可知分别为mbox[str(size)](mbox['300'])和num_classes(默认为21). 下面, 看一下multibox(...)函数的具体内部实现:

  1. # ssd.py
  2. def multibox(vgg, extra_layers, cfg, num_classes):
  3. # cfg = [4, 6, 6, 6, 4, 4]
  4. # num_classes = 21
  5. # ssd总共会选择6个卷积特征图谱进行预测, 分别为, vggnet的conv4_3, 以及extras_layers的5段卷积的输出(每段由两个卷积层组成, 具体可看extras_layers的实现).
  6. # 也就是说, loc_layers 和 conf_layers 分别具有6个预测层.
  7. loc_layers = []
  8. conf_layers = []
  9. vgg_source = [21, -2]
  10. for k, v in enumerate(vgg_source):
  11. loc_layers += [nn.Conv2d(vgg[v].out_channels, cfg[k]*4, kernel_size=3, padding=1]
  12. conf_layers += [nn.Conv2d(vgg[v].out_channels, cfg[k]*num_classes, kernel_size=3, padding=1)]
  13. for k, v in enumerate(extra_layers[1::2], 2):
  14. loc_layers += [nn.Conv2d(v.out_channels, cfg[k]*4, kernel_size=3, padding=1)]
  15. conf_layers += [nn.Conv2d(v.out_channels, cfg[k]*num_classes, kernel_size=3, padding=1)]
  16. return vgg, extra_layers, (loc_layers, conf_layers)

同样, 我们可以将上面的代码写成可读性更强的形式:

  1. # ssd.py
  2. def multibox(vgg, extras, num_classes):
  3. loc_layers = []
  4. conf_layers = []
  5. #vgg_source=[21, -2] # 21 denote conv4_3, -2 denote conv7
  6. # 定义6个坐标预测层, 输出的通道数就是每个像素点上会产生的 default box 的数量
  7. loc1 = nn.Conv2d(vgg[21].out_channels, 4*4, 3, 1, 1) # 利用conv4_3的特征图谱, 也就是 vgg 网络 List 中的第 21 个元素的输出(注意不是第21层, 因为这中间还包含了不带参数的池化层).
  8. loc2 = nn.Conv2d(vgg[-2].out_channels, 6*4, 3, 1, 1) # Conv7
  9. loc3 = nn.Conv2d(vg
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/291202
推荐阅读
相关标签
  

闽ICP备14008679号