当前位置:   article > 正文

Faster RCNN代码详解(一):算法整体结构

faster rcnn代码

本系列博客介绍Faster RCNN算法的细节,以MXNet框架的代码为例。希望可以通过该系列博客让更多同学了解Faster RCNN算法中关于RPN网络的构建、anchor、proposal、损失函数的定义、正负样本的定义等细节,这样对于理解Faster RCNN后续的延伸版本(比如R-FCN、FPN、Mask RCNN)以及其他object detection算法也有一定的帮助。接下来的讲解基本上按照训练代码的顺序进行。

项目地址:https://github.com/precedenceguo/mx-rcnn

该系列博客以端到端(end to end)的训练方式为例来介绍Faster RCNN算法,训练代码所在脚本:~/mx-rcnn/train_end2end.py,该脚本构建了算法的整体结构,主要包含网络结构的构建(以特征提取主网络采用resnet为例,通过~mx-rcnn/rcnn/symbol/symbol_resnet.py脚本的get_resnet_train函数构建)和数据读取(通过~mx-rcnn/rcnn/core/loader.py脚本的AnchorLoader类和~/mx-rcnn/rcnn/io/rpn.py脚本的assign_anchor函数进行读取,前者会调用后者)两部分。

接下来就来看看训练启动脚本:~/mx-rcnn/train_end2end.py的代码细节:

import argparse
import pprint
import mxnet as mx
import numpy as np

from rcnn.logger import logger
from rcnn.config import config, default, generate_config
from rcnn.symbol import *
from rcnn.core import callback, metric
from rcnn.core.loader import AnchorLoader
from rcnn.core.module import MutableModule
from rcnn.utils.load_data import load_gt_roidb, merge_roidb, filter_roidb
from rcnn.utils.load_model import load_param

def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch,
              lr=0.001, lr_step='5'):
    # setup config
    config.TRAIN.BATCH_IMAGES = 1
    config.TRAIN.BATCH_ROIS = 128
    config.TRAIN.END2END = True
    config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True

    # load symbol
# eval语句是执行字符串命令,以args.network为resnet为例,就是调用~mx-rcnn/rcnn/symbol/symbol_resnet.py
# 脚本中的get_resnet_train函数来得到Faster RCNN的网络结构。
# 在该函数中涉及具体的RPN网络、RPN网络的损失函数、检测网络、检测网络的损失函数细节。
    sym = eval('get_' + args.network + '_train')(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    logger.info(pprint.pformat(config))

    # load dataset and prepare imdb for training
# 这部分是从yml文件读取标注信息,主要调用的接口是load_gt_roidb函数
    image_sets = [iset for iset in args.image_set.split('+')]
    roidbs = [load_gt_roidb(args.dataset, image_set, args.root_path, args.dataset_path, flip=not args.no_flip)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb)

    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号