当前位置:   article > 正文

CenterNet目标检测学习记录_centernet sigma值设置

centernet sigma值设置

模型使用

按照centernet源码官方教程执行即可,我使用的是torch 1.4.0版本,dcn好像要替换(忘了…),替换部分参考链接,我跑通demo后,去看源码,发现如果用Hourglass backbone的话,不编译dcn也能正常使用的(之前在windows怎么编译都不行),自己把代码提取出来,定义自己的模型是hourglass backbone就行了(默认是dla34).

训练自己的数据

训练自己的参考这篇博客链接

模型详解

1.hourglass网络

参考这两篇博客,介绍得比较详细链接1链接2,houglass之间的连接部分如下.(每个hourglass都有输出,下图第二个hourglass边幅原因,没画出来,计算损失时要计算每个输出,预测时只取最后一个输出)
在这里插入图片描述hm [batch,ncls,H,W].wh和reg [batch,2,H,W],所以,如果两个物体中心刚好重合,实际只能得到一个预测框,不过作者也说了,coco数据集上这种情况不到千分一,就没处理.

2.heatmap

centernet里面有很多设置,我基本都使用默认设置

1)半径的确定

半径的大小确定来源cornernet,主要是cornernet靠近gt角点的目标框与标签还是有很高的IOU,因此一定范围内的损失权重跟远的负样本不一致.这个范围就是通过IOU来确定半径得到的邻域,半径的确定参考
链接(此处链接结果貌似不太对,看原理即可),严格来说不是用r来计算,应该是rcos/rsin,不过影响不大
但是官方源码那里是有问题的,参考链接,链接
修正如下

def gaussian_radius(det_size, min_overlap=0.7):
  height, width = det_size

  a1  = 1
  b1  = (height + width)
  c1  = width * height * (1 - min_overlap) / (1 + min_overlap)
  sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)

  #r1  = (b1 + sq1) / 2 #源代码是错的
  r1 = (b1 - sq1) / (2 * a1)

  a2  = 4
  b2  = 2 * (height + width)
  c2  = (1 - min_overlap) * width * height
  sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
  #r2  = (b2 + sq2) / 2
  r2 = (b2 - sq2) / (2 * a2)

  a3  = 4 * min_overlap
  b3  = -2 * min_overlap * (height + width)
  c3  = (min_overlap - 1) * width * height
  sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
  #r3  = (b3 + sq3) / 2
  r3 = (b3 + sq3) / (2 * a3)
  return min(r1, r2, r3)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

2)heatmap绘制

在这里插入图片描述
x,y值是相对圆心位置,centernet内 一种 sigma取值如下
在这里插入图片描述

3.损失

1)论文上的heatmap损失

在这里插入图片描述作者采用alpha=2,beta=4

def _neg_loss(pred, gt):
  ''' Modified focal loss. Exactly the same as CornerNet.
      Runs faster and costs a little bit more memory
    Arguments:
      pred (batch x c x h x w)
      gt_regr (batch x c x h x w)
  '''
  pos_inds = gt.eq(1).float()#equal比较函数,正例
  neg_inds = gt.lt(1).float()#反例

  neg_weights = torch.pow(1 - gt, 4)# beta=4 ,alpha=2

  loss = 0

  pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
  neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

  num_pos  = pos_inds.float().sum()
  pos_loss = pos_loss.sum()
  neg_loss = neg_loss.sum()

  if num_pos == 0:
    loss = loss - neg_loss
  else:
    loss = loss - (pos_loss + neg_loss) / num_pos
  return loss

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  1. wh,和偏移量 reg 默认是L1loss
    在这里插入图片描述
    在这里插入图片描述
class RegL1Loss(nn.Module):
  def __init__(self):
    super(RegL1Loss, self).__init__()
  
  def forward(self, output, mask, ind, target):
    pred = _transpose_and_gather_feat(output, ind) #[batch,maxobjs,2]
    mask = mask.unsqueeze(2).expand_as(pred).float()
    # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean')
    loss = F.l1_loss(pred * mask, target * mask, size_average=False)
    loss = loss / (mask.sum() + 1e-4)
    return loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

总损失
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/122335
推荐阅读
相关标签
  

闽ICP备14008679号