
DETR3D Source Code Analysis in MMDetection3D: Overall Workflow


1. tools/dist_train.sh

  • Distributed training command: the config file is pillar.py and the number of GPUs is 8
tools/dist_train.sh projects/configs/obj_dgcnn/pillar.py 8
  • The dist_train.sh script:
#!/usr/bin/env bash
# DETR3D passes in the config path and the GPU count; PORT uses a default unless set
CONFIG=$1
GPUS=$2
PORT=${PORT:-28500}
# Single-node, multi-GPU distributed training: GPUS, PORT and train.py must be given.
# Multi-node training additionally needs the number of nodes, the node rank, etc.
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
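To make the two bash expansions concrete (${PORT:-28500} falls back to 28500 when PORT is unset or empty, and ${@:3} forwards every argument after the GPU count to train.py), here is a small Python sketch that rebuilds the command the script runs. build_launch_cmd is a hypothetical helper for illustration only; it ignores the PYTHONPATH prefix, which puts the repository root on the import path so that the projects package can be imported later.

```python
import os
import sys


def build_launch_cmd(argv, environ, script_dir="tools"):
    """Rebuild the command line that dist_train.sh executes (illustration only).

    argv mirrors the shell's positional parameters:
    argv[0] -> $1 (CONFIG), argv[1] -> $2 (GPUS), argv[2:] -> ${@:3}.
    """
    config, gpus = argv[0], argv[1]
    # ${PORT:-28500}: use PORT unless it is unset or empty
    port = environ.get("PORT") or "28500"
    # ${@:3}: everything after the GPU count is forwarded to train.py
    extra = list(argv[2:])
    return [
        sys.executable, "-m", "torch.distributed.launch",
        f"--nproc_per_node={gpus}",
        f"--master_port={port}",
        os.path.join(script_dir, "train.py"),
        config,
        "--launcher", "pytorch",
        *extra,
    ]
```

For example, build_launch_cmd(["projects/configs/obj_dgcnn/pillar.py", "8", "--seed", "0"], os.environ) produces roughly the argument list of `tools/dist_train.sh projects/configs/obj_dgcnn/pillar.py 8 --seed 0`, with the trailing `--seed 0` passed through to train.py.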

For more details on torch.distributed.launch, see: https://blog.csdn.net/magic_ll/article/details/122359490
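torch.distributed.launch starts --nproc_per_node worker processes on the node and tells each one its place in the job through environment variables (MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, LOCAL_RANK); depending on the PyTorch version it may also pass a --local_rank argument to the script. With the pytorch launcher, mmcv's init_dist then reads RANK from the environment to set up the process group. The sketch below only reproduces the rank arithmetic under these assumptions; launch_process_envs is a hypothetical helper, not a PyTorch API, and 127.0.0.1 is assumed as the single-node master address.

```python
def launch_process_envs(nproc_per_node, master_port, nnodes=1, node_rank=0,
                        master_addr="127.0.0.1"):
    """Per-process environment a torch.distributed.launch-style launcher sets up.

    Simplified sketch: returns one dict per worker process on this node.
    """
    world_size = nnodes * nproc_per_node
    envs = []
    for local_rank in range(nproc_per_node):
        envs.append({
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": str(master_port),
            "WORLD_SIZE": str(world_size),
            # global rank = offset of this node + index of the process on it
            "RANK": str(node_rank * nproc_per_node + local_rank),
            "LOCAL_RANK": str(local_rank),
        })
    return envs
```

With the 8-GPU command above, ranks 0 to 7 all live on one node and WORLD_SIZE is 8; on the second node (node_rank=1) of a two-node job, the same local ranks map to global ranks 8 to 15.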

2. train.py

Argument parsing

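parse_args defines the config file and work dir arguments. The work dir stores the final dumped config, the logs and similar outputs. If --work-dir is not given and the config does not set work_dir either, it defaults to ./work_dirs/<config file name without extension>.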

import argparse
import warnings


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    # ... (part omitted: the remaining arguments and the
    # args = parser.parse_args() call) ...

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args
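In mmdet3d-style tools/train.py scripts, the work dir is resolved after the config is loaded, with the priority command line > work_dir in the config file > config file name. A minimal sketch of that rule as a stand-alone function (resolve_work_dir is hypothetical; the real script assigns cfg.work_dir in place):

```python
import os.path as osp


def resolve_work_dir(cli_work_dir, cfg_work_dir, config_path):
    """Pick the work dir: --work-dir > work_dir in the config > config name."""
    if cli_work_dir is not None:
        return cli_work_dir
    if cfg_work_dir is not None:
        return cfg_work_dir
    # e.g. projects/configs/obj_dgcnn/pillar.py -> ./work_dirs/pillar
    return osp.join("./work_dirs", osp.splitext(osp.basename(config_path))[0])
```

Keeping the fallback tied to the config file name means runs of different configs land in separate folders by default.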

Reading the config file and updating the configuration

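The author puts the custom components in the 'projects/mmdet3d_plugin/' folder and registers them through the Registry class. The code below uses importlib to import that package; importing it runs the registration code, so the custom classes can afterwards be built from the config.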

    args = parse_args()
    cfg = Config.fromfile(args.config)
    # update the loaded config from args: args take precedence over cfg.
    # args supply parameters the cfg file does not define (e.g. work_dir)
    # and can also override values that the cfg file does define
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from plugin/xx, registry will be updated
    if hasattr(cfg, 'plugin'):
        if cfg.plugin:
            # import the whole plugin package into the runtime
            # plugin_dir = 'projects/mmdet3d_plugin/'
            import importlib

            if hasattr(cfg, 'plugin_dir'):
                plugin_dir = cfg.plugin_dir
                # _module_dir = 'projects/mmdet3d_plugin'
                _module_dir = os.path.dirname(plugin_dir)
                _module_dir = _module_dir.split('/')
                _module_path = _module_dir[0]
                # turn the directory path into a dotted Python module path
                for m in _module_dir[1:]:
                    _module_path = _module_path + '.' + m

                plg_lib = importlib.import_module(_module_path)
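The path handling above and the registration it triggers can be sketched in isolation. plugin_dir_to_module reproduces the dirname/split/join steps; Registry, DETECTORS and ToyDetector are hypothetical stand-ins loosely modelled on mmcv's Registry (the real one also supports scopes, custom build functions and more). The point is that importing the plugin package executes decorators like register_module(), which is how the import updates the registry.

```python
import os


def plugin_dir_to_module(plugin_dir):
    """Convert a plugin directory into a dotted module path, as train.py does."""
    # 'projects/mmdet3d_plugin/' -> 'projects/mmdet3d_plugin'
    module_dir = os.path.dirname(plugin_dir)
    # 'projects/mmdet3d_plugin' -> 'projects.mmdet3d_plugin'
    return ".".join(module_dir.split("/"))


class Registry:
    """Toy name -> class table, loosely modelled on mmcv's Registry."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # returns a class decorator that records the class under its name
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        # cfg is a dict whose 'type' key names a registered class
        args = dict(cfg)
        cls = self._module_dict[args.pop("type")]
        return cls(**args)


DETECTORS = Registry("detector")


# In a real plugin package this decorator runs when the package is imported.
@DETECTORS.register_module()
class ToyDetector:
    def __init__(self, num_classes):
        self.num_classes = num_classes
```

Note the trailing slash: os.path.dirname('projects/mmdet3d_plugin') returns 'projects', so plugin_dir should end with '/' for this conversion to reach the plugin package itself.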

Initializing runtime parameters

This part sets up the runtime environment: where outputs such as logs are saved, which GPUs are used, and similar parameters.

    # resume from an earlier checkpoint, if one is given
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config and save to work_dir
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')

    if cfg.model.type in ['EncoderDecoder3D']:
        logger_name = 'mmseg'
    else:
        logger_name = 'mmdet'
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level, name=logger_name)
    # meta: records environment info, the random seed, etc.
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)
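The GPU-id and learning-rate logic above reduces to two small functions. This is a sketch with hypothetical names (resolve_gpu_ids, linear_scaled_lr); the base of 8 GPUs is the one hard-coded in the excerpt. One detail worth checking if you rely on --autoscale-lr: in the excerpt the scaling runs before init_dist re-sets cfg.gpu_ids to the world size, so when training is launched through dist_train.sh without --gpus or --gpu-ids, len(cfg.gpu_ids) is still 1 at that point.

```python
def resolve_gpu_ids(gpu_ids=None, gpus=None, world_size=None):
    """Return the GPU ids train.py would end up with (sketch).

    Non-distributed: explicit --gpu-ids, else range(--gpus), else one GPU.
    Distributed: re-set to range(world_size) after init_dist.
    """
    if world_size is not None:
        return range(world_size)
    if gpu_ids is not None:
        return gpu_ids
    return range(1) if gpus is None else range(gpus)


def linear_scaled_lr(base_lr, num_gpus, base_num_gpus=8):
    """Linear scaling rule: the lr is proportional to the number of GPUs."""
    return base_lr * num_gpus / base_num_gpus
```

For instance, with a base learning rate of 2e-4 (an arbitrary value here), linear_scaled_lr(2e-4, 8) leaves it unchanged and linear_scaled_lr(2e-4, 4) halves it to 1e-4.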

Building the datasets and the model

This part builds the model and the training dataset; depending on the workflow, a validation dataset is built as well.

    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    # dataset initialization; inputs: pipeline, class_names, modality
    # returns the training Dataset consumed by the DataLoader later on
    datasets = [build_dataset(cfg.data.train)]
    # with cfg.workflow = [('train', 1), ('val', 1)], a validation epoch
    # runs after every training epoch (code omitted)
    # set checkpoint (code omitted)

    # set up the config, hooks, dataloaders and runner, then run the
    # runner according to the workflow
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
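The workflow-dependent part that the excerpt omits can be sketched on plain config dicts. In mmdet3d's tools/train.py (which DETR3D's script follows closely), a two-phase workflow such as [('train', 1), ('val', 1)] appends a second dataset built from a deep copy of the val config, but with the training pipeline and test_mode=False, so the runner's 'val' phase can compute losses on it. build_dataset_cfgs is a hypothetical helper that returns configs instead of calling build_dataset:

```python
import copy


def build_dataset_cfgs(data_cfg, workflow):
    """Return the dataset configs that would be built for a given workflow.

    Sketch only: the real script passes these configs to build_dataset().
    """
    dataset_cfgs = [data_cfg["train"]]
    if len(workflow) == 2:
        val_cfg = copy.deepcopy(data_cfg["val"])
        train_cfg = data_cfg["train"]
        # a dataset wrapper keeps the wrapped dataset config under 'dataset'
        source = train_cfg["dataset"] if "dataset" in train_cfg else train_cfg
        # reuse the training pipeline, which loads the annotations for losses
        val_cfg["pipeline"] = source["pipeline"]
        # not a test-mode dataset, so annotations are kept
        val_cfg["test_mode"] = False
        dataset_cfgs.append(val_cfg)
    return dataset_cfgs
```

Deep-copying the val config keeps the original cfg.data.val untouched, so the separate evaluation dataloader built later in train_model still sees the test pipeline.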

How train_model works

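This part builds the DataLoaders, initializes the runner and its hooks, and runs the runner according to the workflow. For detection models such as DETR3D, train_model hands off to train_detector, which is the function shown below.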

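import os

from mmcv.runner import (Fp16OptimizerHook, OptimizerHook, build_optimizer,
                         build_runner)
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import build_dataloader, build_dataset
from mmdet.utils import (build_ddp, build_dp, compat_cfg,
                         find_latest_checkpoint, get_root_logger)


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):

    cfg = compat_cfg(cfg)
    logger = get_root_logger(log_level=cfg.log_level)

    # prepare data loaders; dataset may be a list or a single dataset, since
    # it also holds the val dataset when the workflow has two phases
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # runner type, EpochBasedRunner by default
    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']

    train_dataloader_default_args = dict(
        samples_per_gpu=2,
        workers_per_gpu=2,
        # `num_gpus` will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        runner_type=runner_type,
        persistent_workers=False)

    # merge the defaults above with the settings from the config file;
    # keys from the config file win
    train_loader_cfg = {
        **train_dataloader_default_args,
        # use {} if the config file has no train_dataloader section
        **cfg.data.get('train_dataloader', {})
    }

    # build the dataloaders
    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # put model on gpus
    if distributed:
        # one process per GPU: wrap the model in DistributedDataParallel
        model = build_ddp(
            model,
            cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False,
            find_unused_parameters=cfg.get('find_unused_parameters', False))
    else:
        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

    # build optimizer (auto_scale_lr is a helper from the same train module)
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))
    # keep the .log and .log.json file names in sync with train.py
    runner.timestamp = timestamp

    # optimizer hook: fp16 / distributed variants
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register training hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    # register eval hooks
    if validate:
        val_dataloader_default_args = dict(
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=distributed,
            shuffle=False,
            persistent_workers=False)

        val_dataloader_args = {
            **val_dataloader_default_args,
            **cfg.data.get('val_dataloader', {})
        }
        # Support batch_size > 1 in validation

        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(val_dataset, **val_dataloader_args)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    # resume from the last checkpoint
    resume_from = None
    if cfg.resume_from is None and cfg.get('auto_resume'):
        resume_from = find_latest_checkpoint(cfg.work_dir)
    if resume_from is not None:
        cfg.resume_from = resume_from
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    # run the runner following the workflow
    runner.run(data_loaders, cfg.workflow)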

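runner.run(data_loaders, cfg.workflow) is where the workflow actually takes effect. The toy class below imitates the control flow of mmcv's EpochBasedRunner.run under simplifying assumptions: hooks, logging and the real forward/backward steps are left out, and ToyEpochRunner is a hypothetical name. In the real runner, the hooks registered above (LR, optimizer, checkpoint, logger, and EvalHook/DistEvalHook) are called around these steps; metric evaluation comes from the eval hook after training epochs, not from the 'val' workflow phase.

```python
class ToyEpochRunner:
    """Minimal imitation of the workflow loop in mmcv's EpochBasedRunner.run."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.epoch = 0
        self.log = []  # records (mode, epoch) for each phase that ran

    def train(self, data_loader):
        for _ in data_loader:
            pass  # forward, backward and optimizer step would happen here
        self.log.append(("train", self.epoch))
        self.epoch += 1  # only training advances the epoch counter

    def val(self, data_loader):
        for _ in data_loader:
            pass  # forward pass computing val losses, no optimizer step
        self.log.append(("val", self.epoch))

    def run(self, data_loaders, workflow):
        # one data loader per workflow entry, e.g. [('train', 1), ('val', 1)]
        assert len(data_loaders) == len(workflow)
        while self.epoch < self.max_epochs:
            for loader, (mode, epochs) in zip(data_loaders, workflow):
                phase = getattr(self, mode)
                for _ in range(epochs):
                    if mode == "train" and self.epoch >= self.max_epochs:
                        break
                    phase(loader)
```

With max_epochs=2 and workflow [('train', 1), ('val', 1)], the phases alternate train, val, train, val; this is also why data_loaders must contain one loader per workflow entry, which the two-dataset list built in train.py provides.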
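Disclaimer: This article was contributed by a user and does not represent the views of the wpsshop blog; copyright belongs to the original author, and the site assumes no legal liability for it. If you find infringing content, please contact the site. When reposting, please cite the source: https://www.wpsshop.cn/w/花生_TL007/article/detail/337868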