tools/dist_train.sh projects/configs/obj_dgcnn/pillar.py 8
#!/usr/bin/env bash
# DETR3D passes in config_path and gpus; the port uses the default value
CONFIG=$1
GPUS=$2
PORT=${PORT:-28500}
# distributed here means single-machine multi-GPU training: gpus, port and train.py must be specified;
# for multi-machine multi-GPU training the number of nodes, the node rank, etc. must also be given
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
More details on torch.distributed.launch: https://blog.csdn.net/magic_ll/article/details/122359490
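For orientation, here is a minimal sketch (not part of the repo) of the per-worker setup implied by `--launcher pytorch`: torch.distributed.launch starts GPUS processes and hands each one its local rank (as a `--local_rank` argument or the `LOCAL_RANK` environment variable, depending on the PyTorch version), together with RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT, which init_dist in train.py ultimately feeds into torch.distributed.init_process_group.

import os
import torch
import torch.distributed as dist

# sketch only: assumes LOCAL_RANK is exported by the launcher (newer launchers do;
# older ones pass --local_rank as a command-line argument instead)
local_rank = int(os.environ.get('LOCAL_RANK', 0))
torch.cuda.set_device(local_rank)           # one process per GPU
dist.init_process_group(backend='nccl')     # reads RANK / WORLD_SIZE / MASTER_* from the environment
print(f'rank {dist.get_rank()} of {dist.get_world_size()} ready')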
Set the config file and work dir. The work dir stores the final config, logs and other outputs; it defaults to path/to/user/work_dir/.
def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    ''' part of the code omitted '''
    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options
    return args
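As a small illustration (values are made up, not from any DETR3D config), the `--cfg-options` key=value pairs are parsed into a dict and merged into the config with `Config.merge_from_dict`, where dotted keys address nested fields:

from mmcv import Config

# hypothetical config values, for illustration only
cfg = Config(dict(optimizer=dict(type='AdamW', lr=2e-4), total_epochs=24))
# roughly what `--cfg-options optimizer.lr=1e-4 total_epochs=12` turns into
cfg.merge_from_dict({'optimizer.lr': 1e-4, 'total_epochs': 12})
print(cfg.optimizer.lr, cfg.total_epochs)  # 0.0001 12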
The authors put their custom components under 'projects/mmdet3d_plugin/' and register the modules through the Registry class; here importlib is used to import the plugin package so that the custom classes get registered and can be instantiated.
args = parse_args()
cfg = Config.fromfile(args.config)
# Update the loaded config from args. args has higher priority than cfg:
# args defines parameters such as work_dir that are not in the config file,
# and some of its values override those in cfg.
if args.cfg_options is not None:
    cfg.merge_from_dict(args.cfg_options)

# import modules from plugin/xx, registry will be updated
if hasattr(cfg, 'plugin'):
    if cfg.plugin:
        # import the whole plugin package into the environment
        # plugin_dir = 'projects/mmdet3d_plugin/'
        import importlib
        if hasattr(cfg, 'plugin_dir'):
            plugin_dir = cfg.plugin_dir
            # _module_dir = 'projects/mmdet3d_plugin'
            _module_dir = os.path.dirname(plugin_dir)
            _module_dir = _module_dir.split('/')
            _module_path = _module_dir[0]
            # convert the directory path into Python's dotted package path
            for m in _module_dir[1:]:
                _module_path = _module_path + '.' + m
            plg_lib = importlib.import_module(_module_path)
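The registry mechanism this plugin import relies on can be sketched as follows (a toy example with made-up names, not the actual DETR3D classes): simply importing the module that contains the decorated class is enough to make it buildable from a config dict's `type` field.

from mmcv.utils import Registry, build_from_cfg

MODELS = Registry('models')   # stand-in for mmdet's DETECTORS registry

@MODELS.register_module()     # registration happens at import time
class ToyDetector:
    def __init__(self, num_classes=10):
        self.num_classes = num_classes

# a config dict with a 'type' key is enough to instantiate the class
model = build_from_cfg(dict(type='ToyDetector', num_classes=3), MODELS)
print(type(model).__name__, model.num_classes)  # ToyDetector 3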
This part sets up the runtime environment: the directory where the model's outputs are saved, the GPUs to use, and other runtime parameters.
# override cfg entries with command-line args (resume checkpoint, gpu ids, lr scaling)
if args.resume_from is not None:
    cfg.resume_from = args.resume_from
if args.gpu_ids is not None:
    cfg.gpu_ids = args.gpu_ids
else:
    cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
if args.autoscale_lr:
    # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
    cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
    distributed = False
else:
    distributed = True
    init_dist(args.launcher, **cfg.dist_params)
    # re-set gpu_ids with distributed training mode
    _, world_size = get_dist_info()
    cfg.gpu_ids = range(world_size)

# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# dump config and save to work_dir
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
if cfg.model.type in ['EncoderDecoder3D']:
    logger_name = 'mmseg'
else:
    logger_name = 'mmdet'
logger = get_root_logger(
    log_file=log_file, log_level=cfg.log_level, name=logger_name)

# meta: stores environment info, random seed, etc.
meta = dict()
# log env info
meta['env_info'] = env_info
meta['config'] = cfg.pretty_text
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
meta['exp_name'] = osp.basename(args.config)
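A quick worked example of the `--autoscale-lr` branch above (numbers are illustrative): the config's learning rate is assumed to be tuned for 8 GPUs and is scaled linearly with the number of GPUs actually used.

base_lr = 2e-4                   # hypothetical lr from the config, tuned for 8 GPUs
num_gpus = 4                     # len(cfg.gpu_ids) on this machine
print(base_lr * num_gpus / 8)    # 0.0001 -> the lr actually used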
This part builds the model and initializes the train and val datasets.
model = build_model(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'))
model.init_weights()

# dataset initialization, input: pipeline, class_names, modality
# returns the train Dataset used by the DataLoader later
datasets = [build_dataset(cfg.data.train)]
# setting cfg.workflow = [('train', 1), ('val', 1)] evaluates the val set
# after every training epoch: code omitted
# set checkpoint: code omitted

# initialize config, hooks, dataloader and runner, then run the runner to
# start training according to the workflow
train_model(
    model,
    datasets,
    cfg,
    distributed=distributed,
    validate=(not args.no_validate),
    timestamp=timestamp,
    meta=meta)
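For reference, the `cfg.data.train` dict passed to `build_dataset` has its `type` looked up in the DATASETS registry, with the remaining keys forwarded to the dataset's `__init__`. A sketch with hypothetical values (not taken from the obj_dgcnn/pillar config) might look like:

class_names = ['car', 'truck', 'pedestrian']           # illustrative
train = dict(
    type='NuScenesDataset',                            # resolved via the DATASETS registry
    data_root='data/nuscenes/',
    ann_file='data/nuscenes/nuscenes_infos_train.pkl',
    pipeline=[dict(type='LoadPointsFromFile', coord_type='LIDAR',
                   load_dim=5, use_dim=5)],            # list of data-transform dicts
    classes=class_names,
    modality=dict(use_lidar=True, use_camera=False),
    test_mode=False)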
This part, implemented in train_detector (train_model dispatches to it for detection models), builds the DataLoaders, constructs the runner and registers its hooks, and then runs the runner according to the workflow.
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    cfg = compat_cfg(cfg)
    logger = get_root_logger(log_level=cfg.log_level)

    # prepare data loaders; dataset may be a list or a single dataset,
    # because a workflow of length 2 also includes the val dataset
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    # resolve the runner type; default runner: EpochBasedRunner
    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']

    train_dataloader_default_args = dict(
        samples_per_gpu=2,
        workers_per_gpu=2,
        # `num_gpus` will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        runner_type=runner_type,
        persistent_workers=False)

    # merge the default dataloader settings above with those from the config
    # file; if the config defines no train_dataloader, this falls back to {}
    train_loader_cfg = {
        **train_dataloader_default_args,
        **cfg.data.get('train_dataloader', {})
    }

    # build the dataloaders
    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # put model on gpus
    model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

    # build optimizer
    auto_scale_lr(cfg, distributed, logger)
    optimizer = build_optimizer(model, cfg.optimizer)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # register training hooks (fp16 / OptimizerHook wrapping omitted in this excerpt)
    optimizer_config = cfg.optimizer_config
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    # register eval hooks
    if validate:
        val_dataloader_default_args = dict(
            samples_per_gpu=1,
            workers_per_gpu=2,
            dist=distributed,
            shuffle=False,
            persistent_workers=False)
        val_dataloader_args = {
            **val_dataloader_default_args,
            **cfg.data.get('val_dataloader', {})
        }
        # Support batch_size > 1 in validation
        # build the val dataset in test mode
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(val_dataset, **val_dataloader_args)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    # resume from the latest checkpoint
    resume_from = None
    if cfg.resume_from is None and cfg.get('auto_resume'):
        resume_from = find_latest_checkpoint(cfg.work_dir)
    if resume_from is not None:
        cfg.resume_from = resume_from
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    # run the runner iteratively according to the workflow
    runner.run(data_loaders, cfg.workflow)
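Finally, the `cfg.workflow` consumed by `runner.run` is just a list of (mode, epochs) tuples telling the EpochBasedRunner which loop to execute next; for example (illustrative values):

workflow = [('train', 1)]                 # train only; validation is handled by the EvalHook
# workflow = [('train', 1), ('val', 1)]   # alternatively, run one val epoch after every train epoch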