赞
踩
本系列博客介绍Faster RCNN算法的细节,以MXNet框架的代码为例。希望可以通过该系列博客让更多同学了解Faster RCNN算法中关于RPN网络的构建、anchor、proposal、损失函数的定义、正负样本的定义等细节,这样对于理解Faster RCNN后续的延伸版本(比如R-FCN、FPN、Mask RCNN)以及其他object detection算法也有一定的帮助。接下来的讲解基本上按照训练代码的顺序进行。
项目地址:https://github.com/precedenceguo/mx-rcnn
该系列博客以端到端(end to end)的训练方式为例来介绍Faster RCNN算法,训练代码所在脚本:~/mx-rcnn/train_end2end.py,该脚本构建了算法的整体结构,主要包含网络结构的构建(以特征提取主网络采用resnet为例,通过~mx-rcnn/rcnn/symbol/symbol_resnet.py脚本的get_resnet_train函数构建)和数据读取(通过~mx-rcnn/rcnn/core/loader.py脚本的AnchorLoader类和~/mx-rcnn/rcnn/io/rpn.py脚本的assign_anchor函数进行读取,前者会调用后者)两部分。
接下来就来看看训练启动脚本:~/mx-rcnn/train_end2end.py的代码细节:
import argparse
import pprint
import mxnet as mx
import numpy as np
from rcnn.logger import logger
from rcnn.config import config, default, generate_config
from rcnn.symbol import *
from rcnn.core import callback, metric
from rcnn.core.loader import AnchorLoader
from rcnn.core.module import MutableModule
from rcnn.utils.load_data import load_gt_roidb, merge_roidb, filter_roidb
from rcnn.utils.load_model import load_param
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch,
lr=0.001, lr_step='5'):
# setup config
config.TRAIN.BATCH_IMAGES = 1
config.TRAIN.BATCH_ROIS = 128
config.TRAIN.END2END = True
config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True
# load symbol
# eval语句是执行字符串命令,以args.network为resnet为例,就是调用~mx-rcnn/rcnn/symbol/symbol_resnet.py
# 脚本中的get_resnet_train函数来得到Faster RCNN的网络结构。
# 在该函数中涉及具体的RPN网络、RPN网络的损失函数、检测网络、检测网络的损失函数细节。
sym = eval('get_' + args.network + '_train')(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
feat_sym = sym.get_internals()['rpn_cls_score_output']
# setup multi-gpu
batch_size = len(ctx)
input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size
# print config
logger.info(pprint.pformat(config))
# load dataset and prepare imdb for training
# 这部分是从yml文件读取标注信息,主要调用的接口是load_gt_roidb函数
image_sets = [iset for iset in args.image_set.split('+')]
roidbs = [load_gt_roidb(args.dataset, image_set, args.root_path, args.dataset_path, flip=not args.no_flip)
for image_set in image_sets]
roidb = merge_roidb(roidbs)
roidb = filter_roidb(roidb)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。