当前位置:   article > 正文

Faster RCNN代码详解(五):关于检测网络(Fast RCNN)的proposal_faster rcnn proposal

faster rcnn proposal

Faster RCNN代码详解(二):网络结构构建中介绍了Faster RCNN算法的网络结构,其中有一个用于生成ROI proposal target的自定义层,该自定义层的输出作为检测网络(Fast RCNN)的输入,这篇博客就来介绍这个自定义层的内容。

该自定义层的实现所在脚本~/mx-rcnn/rcnn/symbol/proposal_target.py,该层返回的group列表包含4个值,分别是rois,label,bbox_target,bbox_weight。roi用于ROI Pooling层,label用于检测网络的分类支路、bbox_target和bbox_weight用于检测网络的回归支路

通过系列四中对RPN网络中anchor的介绍,你应该明白这里的label、bbox_target、bbox_weight和RPN网络中的不同,RPN网络中的label、bbox_target和bbox_weight等变量的定义方式和这里不同,同时在RPN网络中那边变量是服务于anchor的。

"""
Proposal Target Operator selects foreground and background roi and assigns label, bbox_transform to them.
"""

import logging
import mxnet as mx
import numpy as np
from distutils.util import strtobool

from ..logger import logger
from rcnn.io.rcnn import sample_rois

class ProposalTargetOperator(mx.operator.CustomOp):
    def __init__(self, num_classes, batch_images, batch_rois, fg_fraction):
        super(ProposalTargetOperator, self).__init__()
        self._num_classes = num_classes
        self._batch_images = batch_images
        self._batch_rois = batch_rois
        self._fg_fraction = fg_fraction

        if logger.level == logging.DEBUG:
            self._count = 0
            self._fg_num = 0
            self._bg_num = 0

    def forward(self, is_train, req, in_data, out_data, aux):
        assert self._batch_rois % self._batch_images == 0, \
            'BATCHIMAGES {} must devide BATCH_ROIS {}'.format(self._batch_images, self._batch_rois)
        rois_per_image = int(self._batch_rois / self._batch_images)
        fg_rois_per_image = int(round(self._fg_fraction * rois_per_image))


# all_rois的维度是(2000,5),不过all_rois除了4列坐标外,剩下一列全是0,
# 并不表示roi的标签,仅仅是batch的index标识。gt_boxes的维度是(x,5),x是object的数量。
        all_rois = in_data[0].asnumpy()
        gt_boxes = in_data[1].asnumpy()

        # Include ground-truth boxes in the set of candidate rois
# 初始化的zeros替换掉gt_boxes中object的类别,然后和原来的all_rois做合并,
# 最后得到的all_rois的维度是(2000+x,5)。因为all_rois变量中并不需要ground truth的标签,
# 所以都用0值替代。从最后的assert语句也可以看出第一列0值的含义是和batch相关。
        zeros = np.zeros((gt_boxes.shape[0], 1), dtype=gt_boxes.dtype)
        all_rois = np.vstack((all_rois, np.hstack((zeros, gt_boxes[:, :-1]))))
        # Sanity check: single batch only
        assert</
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/250102
推荐阅读
相关标签
  

闽ICP备14008679号