
Object Detection Paper: FCOS: Fully Convolutional One-Stage Object Detection, with a PyTorch Implementation


A one-stage object detection algorithm with no anchors, no proposals, and a small memory footprint.

PDF: https://arxiv.org/pdf/1904.01355.pdf
PyTorch: https://github.com/tianzhi0549/FCOS/
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

1 Overview

Key contributions of the paper:

  1. Solves object detection in a per-pixel prediction fashion, borrowing the idea of semantic segmentation;
  2. Eliminates the anchor boxes and object proposals common in object detection, so no hyper-parameters related to them need tuning;
  3. Avoids the heavy IoU computation between ground-truth boxes and anchor boxes during training, lowering memory usage;
  4. The proposed FCOS can replace the RPN in two-stage detectors, with better performance.

2 The FCOS Framework

The overall network architecture:
[Backbone] + [FPN] + [Classification+Regression+Center-ness]

2-1 The Proposed FCOS (Fully Convolutional One-Stage Object Detector)

For a point $(x, y)$ on feature map $F_{i}$, the corresponding location on the input image is $(\lfloor s/2 \rfloor + xs,\ \lfloor s/2 \rfloor + ys)$, where $s$ is the stride of $F_{i}$. Anchor-based detectors treat this location as the center of a set of anchor boxes and regress the bounding box with the anchors as references; FCOS instead treats the location itself as a training sample and regresses the bounding box from it directly.
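A minimal sketch of this mapping (the helper `compute_locations` is my naming, not code from the official repo):

import torch

def compute_locations(h, w, stride):
    # Image-plane (x, y) coordinates of every point on an h x w feature map:
    # point (x, y) maps to (floor(stride/2) + x*stride, floor(stride/2) + y*stride).
    shift = stride // 2
    xs = (shift + stride * torch.arange(w, dtype=torch.float32)).view(1, w).expand(h, w)
    ys = (shift + stride * torch.arange(h, dtype=torch.float32)).view(h, 1).expand(h, w)
    return torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=1)  # (h*w, 2)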
FCOS regresses a 4-D vector $t^{*} = (l^{*}, t^{*}, r^{*}, b^{*})$, where $l^{*}$, $t^{*}$, $r^{*}$, $b^{*}$ are the distances from the location to the left, top, right, and bottom sides of its bounding box. The regression targets are
$$l^{*} = x - x_{0}^{(i)}, \quad t^{*} = y - y_{0}^{(i)}, \quad r^{*} = x_{1}^{(i)} - x, \quad b^{*} = y_{1}^{(i)} - y,$$
where $(x_{0}^{(i)}, y_{0}^{(i)})$ and $(x_{1}^{(i)}, y_{1}^{(i)})$ are the top-left and bottom-right corners of the ground-truth box $B_{i}$ containing the location.
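A hedged sketch of computing these targets for one ground-truth box (helper name mine):

import torch

def regression_targets(locations, gt_box):
    # locations: (N, 2) image-plane (x, y) points; gt_box: (x0, y0, x1, y1).
    x, y = locations[:, 0], locations[:, 1]
    x0, y0, x1, y1 = gt_box
    targets = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=1)  # (N, 4)
    is_inside = targets.min(dim=1).values > 0  # only points inside the box are positives
    return targets, is_inside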
Compared with anchor-based detectors, which take only the anchor boxes with the highest overlap against the ground truth as positive samples, FCOS can exploit considerably more positive samples. The overall loss function is:
$$L(\{p_{x,y}\}, \{t_{x,y}\}) = \frac{1}{N_{pos}} \sum_{x,y} L_{cls}(p_{x,y}, c^{*}_{x,y}) + \frac{\lambda}{N_{pos}} \sum_{x,y} \mathbb{1}_{\{c^{*}_{x,y} > 0\}}\, L_{reg}(t_{x,y}, t^{*}_{x,y}),$$
where $L_{cls}$ is the focal loss, $L_{reg}$ is the IoU loss, $N_{pos}$ is the number of positive samples, and the regression term is evaluated only at positive locations ($c^{*}_{x,y} > 0$).
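A hedged sketch of this loss (the names `fcos_loss` and `iou_loss_ltrb` are mine; following the paper, the classification term is the sigmoid focal loss and the regression term is the IoU loss evaluated directly on the (l, t, r, b) parameterization):

import torch
import torch.nn.functional as F

def iou_loss_ltrb(pred, target, eps=1e-6):
    # IoU loss between boxes sharing a point, parameterized as distances (l, t, r, b).
    # Assumes pred distances are positive (e.g., after exp() on the raw outputs).
    pl, pt, pr, pb = pred.unbind(dim=1)
    tl, tt, tr, tb = target.unbind(dim=1)
    area_p = (pl + pr) * (pt + pb)
    area_t = (tl + tr) * (tt + tb)
    inter = (torch.min(pl, tl) + torch.min(pr, tr)) * \
            (torch.min(pt, tt) + torch.min(pb, tb))
    iou = inter / (area_p + area_t - inter).clamp(min=eps)
    return -torch.log(iou.clamp(min=eps))

def fcos_loss(cls_logits, reg_pred, cls_targets, reg_targets, pos_mask,
              alpha=0.25, gamma=2.0, lam=1.0):
    # cls_logits/cls_targets: (N, C) logits and float one-hot labels over all locations;
    # reg_pred/reg_targets: (N, 4); pos_mask: (N,) bool, True at positive samples.
    num_pos = pos_mask.float().sum().clamp(min=1.0)
    # Focal loss over all locations, normalized by the number of positives.
    p = torch.sigmoid(cls_logits)
    ce = F.binary_cross_entropy_with_logits(cls_logits, cls_targets, reduction='none')
    p_t = p * cls_targets + (1 - p) * (1 - cls_targets)
    a_t = alpha * cls_targets + (1 - alpha) * (1 - cls_targets)
    loss_cls = (a_t * (1 - p_t) ** gamma * ce).sum() / num_pos
    # IoU regression loss, only over positive locations.
    loss_reg = iou_loss_ltrb(reg_pred[pos_mask], reg_targets[pos_mask]).sum() / num_pos
    return loss_cls + lam * loss_reg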

2-2 Multi-level Prediction with FPN

FPN-based multi-level prediction improves recall and alleviates the ambiguity caused by overlapping bounding boxes. FCOS uses five feature maps $\{P_3, P_4, P_5, P_6, P_7\}$ with strides 8, 16, 32, 64, and 128, where $P_6$ and $P_7$ are produced by downsampling $P_5$ and $P_6$, respectively.

Unlike anchor-based detectors, which regress anchor boxes of different scales at different levels, FCOS limits the range of regression distances per level: with $m_2, m_3, m_4, m_5, m_6, m_7 = 0, 64, 128, 256, 512, \infty$, a location is used as a positive sample on level $i$ only if $\max(l^{*}, t^{*}, r^{*}, b^{*})$ falls in $(m_{i-1}, m_{i}]$; objects outside a level's range are not regressed on that level (as sketched below). This effectively reduces the ambiguity caused by overlapping objects, under the authors' assumption that overlapping objects usually differ substantially in size.
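A sketch of this per-level assignment rule (the helper `assign_levels` is my naming, reusing the (l*, t*, r*, b*) targets from the previous section):

import torch

def assign_levels(reg_targets, bounds=(0, 64, 128, 256, 512, float('inf'))):
    # reg_targets: (N, 4) distances (l*, t*, r*, b*).
    # A location belongs to level i iff m_{i-1} < max(l*, t*, r*, b*) <= m_i.
    max_dist = reg_targets.max(dim=1).values
    return [(max_dist > lo) & (max_dist <= hi)       # one bool mask per level, P3..P7
            for lo, hi in zip(bounds[:-1], bounds[1:])]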

2-3 Center-ness

"Center-ness" suppresses the low-quality detection boxes produced by locations far from an object's center, quickly filters them out, reduces the burden on NMS, and improves detection performance. Center-ness measures how close a location is to the center of the object it falls in; that is, FCOS also takes into account where a point sits within its target, giving larger weight the closer the point is to the center. The center-ness target is defined as
$$\text{centerness}^{*} = \sqrt{\frac{\min(l^{*}, r^{*})}{\max(l^{*}, r^{*})} \times \frac{\min(t^{*}, b^{*})}{\max(t^{*}, b^{*})}}$$
During training, the center-ness branch is supervised so that its output is close to 1 at an object's center and decays toward 0 at the edges of the box. At inference, the predicted center-ness is multiplied into the classification score, so the low-quality boxes produced by locations near object boundaries are ranked low and easily filtered out by non-maximum suppression (NMS), improving detection performance.
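A minimal sketch of the center-ness target (helper name mine; the formula is the one just given):

import torch

def centerness_target(reg_targets, eps=1e-8):
    # reg_targets: (N, 4) distances (l*, t*, r*, b*) of positive locations.
    l, t, r, b = reg_targets.unbind(dim=1)
    lr = torch.min(l, r) / torch.max(l, r).clamp(min=eps)
    tb = torch.min(t, b) / torch.max(t, b).clamp(min=eps)
    return torch.sqrt(lr * tb)  # 1 at the box center, decaying to 0 at the edges

At inference, the predicted center-ness (after a sigmoid) is simply multiplied into the classification score before NMS, so off-center, low-quality boxes are ranked low and removed.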

3 Experimental Results

[Tables: detection results on the COCO benchmark; see the paper]
PyTorch code:

import torch
import torch.nn as nn
import torchvision

# 3x3 convolution followed by ReLU6, used to build the prediction heads.
def Conv3x3ReLU(in_channels,out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,stride=1,padding=1),
        nn.ReLU6(inplace=True)
    )

# Box-regression head: four conv blocks, then a 3x3 conv predicting (l, t, r, b).
def locLayer(in_channels,out_channels):
    return nn.Sequential(
            Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
            Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
            Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
            Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
        )

# Shared head predicting class scores and center-ness (split later along channels).
def conf_centernessLayer(in_channels,out_channels):
    return nn.Sequential(
        Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
        Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
        Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
        Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels),
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
    )

class FCOS(nn.Module):
    # ResNet-50 backbone + FPN (P3-P7) + per-level prediction heads.
    def __init__(self, num_classes=21):
        super(FCOS, self).__init__()
        self.num_classes = num_classes
        resnet = torchvision.models.resnet50()
        layers = list(resnet.children())

        self.layer1 = nn.Sequential(*layers[:5])
        self.layer2 = nn.Sequential(*layers[5])
        self.layer3 = nn.Sequential(*layers[6])
        self.layer4 = nn.Sequential(*layers[7])

        self.lateral5 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=1)
        self.lateral4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)
        self.lateral3 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1)

        self.upsample4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)

        self.downsample6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.downsample5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)

        self.loc_layer3 = locLayer(in_channels=256,out_channels=4)
        self.conf_centerness_layer3 = conf_centernessLayer(in_channels=256,out_channels=self.num_classes+1)

        self.loc_layer4 = locLayer(in_channels=256, out_channels=4)
        self.conf_centerness_layer4 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer5 = locLayer(in_channels=256, out_channels=4)
        self.conf_centerness_layer5 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer6 = locLayer(in_channels=256, out_channels=4)
        self.conf_centerness_layer6 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.loc_layer7 = locLayer(in_channels=256, out_channels=4)
        self.conf_centerness_layer7 = conf_centernessLayer(in_channels=256, out_channels=self.num_classes + 1)

        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Bottom-up pathway: ResNet-50 stages.
        x = self.layer1(x)
        c3 = x = self.layer2(x)  # C3, stride 8
        c4 = x = self.layer3(x)  # C4, stride 16
        c5 = x = self.layer4(x)  # C5, stride 32

        # Top-down pathway with lateral connections (FPN).
        p5 = self.lateral5(c5)
        p4 = self.upsample4(p5) + self.lateral4(c4)
        p3 = self.upsample3(p4) + self.lateral3(c3)

        # P6 and P7 via stride-2 convolutions on P5 and P6.
        p6 = self.downsample5(p5)
        p7 = self.downsample6(p6)

        # Per-level heads; class scores and center-ness are split from one branch.
        loc3 = self.loc_layer3(p3)
        conf_centerness3 = self.conf_centerness_layer3(p3)
        conf3, centerness3 = conf_centerness3.split([self.num_classes, 1], dim=1)

        loc4 = self.loc_layer4(p4)
        conf_centerness4 = self.conf_centerness_layer4(p4)
        conf4, centerness4 = conf_centerness4.split([self.num_classes, 1], dim=1)

        loc5 = self.loc_layer5(p5)
        conf_centerness5 = self.conf_centerness_layer5(p5)
        conf5, centerness5 = conf_centerness5.split([self.num_classes, 1], dim=1)

        loc6 = self.loc_layer6(p6)
        conf_centerness6 = self.conf_centerness_layer6(p6)
        conf6, centerness6 = conf_centerness6.split([self.num_classes, 1], dim=1)

        loc7 = self.loc_layer7(p7)
        conf_centerness7 = self.conf_centerness_layer7(p7)
        conf7, centerness7 = conf_centerness7.split([self.num_classes, 1], dim=1)

        # Flatten each level to (N, -1) and concatenate across levels.
        locs = torch.cat([loc3.permute(0, 2, 3, 1).contiguous().view(loc3.size(0), -1),
                    loc4.permute(0, 2, 3, 1).contiguous().view(loc4.size(0), -1),
                    loc5.permute(0, 2, 3, 1).contiguous().view(loc5.size(0), -1),
                    loc6.permute(0, 2, 3, 1).contiguous().view(loc6.size(0), -1),
                    loc7.permute(0, 2, 3, 1).contiguous().view(loc7.size(0), -1)],dim=1)

        confs = torch.cat([conf3.permute(0, 2, 3, 1).contiguous().view(conf3.size(0), -1),
                           conf4.permute(0, 2, 3, 1).contiguous().view(conf4.size(0), -1),
                           conf5.permute(0, 2, 3, 1).contiguous().view(conf5.size(0), -1),
                           conf6.permute(0, 2, 3, 1).contiguous().view(conf6.size(0), -1),
                           conf7.permute(0, 2, 3, 1).contiguous().view(conf7.size(0), -1),], dim=1)

        centernesses = torch.cat([centerness3.permute(0, 2, 3, 1).contiguous().view(centerness3.size(0), -1),
                           centerness4.permute(0, 2, 3, 1).contiguous().view(centerness4.size(0), -1),
                           centerness5.permute(0, 2, 3, 1).contiguous().view(centerness5.size(0), -1),
                           centerness6.permute(0, 2, 3, 1).contiguous().view(centerness6.size(0), -1),
                           centerness7.permute(0, 2, 3, 1).contiguous().view(centerness7.size(0), -1), ], dim=1)

        out = (locs, confs, centernesses)
        return out

if __name__ == '__main__':
    model = FCOS()
    print(model)

    input = torch.randn(1, 3, 800, 1024)  # H and W divisible by 32 so the FPN additions align
    out = model(input)
    print(out[0].shape)
    print(out[1].shape)
    print(out[2].shape)