赞
踩
修改train中模型的加载方式
路径在ultralytics/yolo/engine/model.py中
01找到
def train(self, **kwargs):
02注释掉下面代码
self.trainer.model = self.model
03修改完整体的代码为
# region def train -- model training entry point (tutorial-modified copy)
def train(self, **kwargs):
    """
    Trains the model on a given dataset.

    Args:
        **kwargs (Any): Any number of arguments representing the training configuration.
    """
    self._check_is_pytorch_model()
    if self.session:  # Ultralytics HUB session
        if any(kwargs):
            LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
        kwargs = self.session.train_args
        check_pip_update_available()
    overrides = self.overrides.copy()
    overrides.update(kwargs)
    if kwargs.get('cfg'):
        LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")

        # region tutorial change: append_filename set to False
        overrides = yaml_load(check_yaml(kwargs['cfg']),append_filename=False)
        # endregion
    overrides['mode'] = 'train'
    if not overrides.get('data'):
        raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
    if overrides.get('resume'):
        overrides['resume'] = self.ckpt_path
    self.task = overrides.get('task') or self.task
    self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)

    # region re-initialize model -- original model wiring disabled by the tutorial
    # if not overrides.get('resume'):  # manually set model only if not resuming
    #     self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
    #     self.model = self.trainer.model
    # endregion
    # region custom addition (also left disabled by the tutorial)
    # self.trainer.model = self.model
    # endregion
    # NOTE(review): with both regions commented out, self.trainer.model is never
    # assigned here -- presumably the trainer's own resume path loads the model
    # instead; confirm before running a non-resume training.

    self.trainer.hub_session = self.session  # attach optional HUB session

    # region start training
    self.trainer.train()
    # Update model and cfg after training
    if RANK in (-1, 0):  # main process only (single device or DDP rank 0)
        self.model, _ = attempt_load_one_weight(str(self.trainer.best))
        self.overrides = self.model.args
        self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP
    # endregion
# endregion
路径在ultralytics/yolo/engine/trainer.py
01找到check_resume方法
def check_resume(self):
02 注释掉下面代码
resume = self.args.resume
03替换为直接加载路径
resume = r'D:\learn\sdxx\mbjc\ultralytics\weights\last.pt'
04修改完整体的代码为
def check_resume(self):
    """Check if a resume checkpoint exists and update ``self.args`` accordingly.

    NOTE(review): the tutorial replaces ``self.args.resume`` with a hard-coded
    absolute checkpoint path. That override is kept, but we now fall back to
    the configured ``self.args.resume`` when the hard-coded file does not
    exist, so the method still behaves sensibly on other machines. Behavior is
    unchanged whenever the hard-coded file is present.

    Raises:
        FileNotFoundError: If no valid checkpoint can be located/loaded.
    """
    # region tutorial change: prefer a fixed last.pt over self.args.resume
    hardcoded = r'D:\learn\sdxx\mbjc\ultralytics\weights\last.pt'
    resume = hardcoded if Path(hardcoded).exists() else self.args.resume
    # endregion

    if resume:
        try:
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # Check that resume data YAML exists, otherwise strip to force re-download of dataset
            ckpt_args = attempt_load_weights(last).args
            if not Path(ckpt_args['data']).exists():
                ckpt_args['data'] = self.args.data

            self.args = get_cfg(ckpt_args)
            self.args.model, resume = str(last), True  # reinstate
        except Exception as e:
            raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
                                    "i.e. 'yolo train resume model=path/to/last.pt'") from e
    self.resume = resume
路径在ultralytics/yolo/engine/trainer.py
01找到
def resume_training(self, ckpt):
在下面加入
ckpt=torch.load(r'D:\learn\sdxx\mbjc\ultralytics\weights\last.pt')
整体的代码为
def resume_training(self, ckpt):
    """Resume YOLO training from the epoch and best fitness stored in a checkpoint.

    Args:
        ckpt (dict | None): Checkpoint dict (epoch, optimizer, ema, ...) or None.

    NOTE(review): the tutorial unconditionally overwrote ``ckpt`` with a
    hard-coded ``last.pt``, which (a) crashed with FileNotFoundError on any
    machine lacking that exact path and (b) discarded a valid ``ckpt`` passed
    by the caller. The override is kept but now only applies when the file
    actually exists.
    """
    # region tutorial change: force-load a fixed last.pt when it is present
    ckpt_path = Path(r'D:\learn\sdxx\mbjc\ultralytics\weights\last.pt')
    if ckpt_path.exists():
        # SECURITY: torch.load unpickles arbitrary objects -- load trusted checkpoints only.
        ckpt = torch.load(ckpt_path)
    # endregion

    if ckpt is None:
        return
    best_fitness = 0.0
    start_epoch = ckpt['epoch'] + 1
    if ckpt['optimizer'] is not None:
        self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
        best_fitness = ckpt['best_fitness']
    if self.ema and ckpt.get('ema'):
        self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
        self.ema.updates = ckpt['updates']
    if self.resume:
        # A finished run (start_epoch == 0 after wraparound logic) cannot be resumed.
        assert start_epoch > 0, \
            f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
        LOGGER.info(
            f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
    if self.epochs < start_epoch:
        # Checkpoint already trained past the requested epoch count: fine-tune extra epochs instead.
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
        self.epochs += ckpt['epoch']  # finetune additional epochs
    self.best_fitness = best_fitness
    self.start_epoch = start_epoch
    if start_epoch > (self.epochs - self.args.close_mosaic):
        # Already inside the close-mosaic window: disable mosaic augmentation now.
        LOGGER.info('Closing dataloader mosaic')
        if hasattr(self.train_loader.dataset, 'mosaic'):
            self.train_loader.dataset.mosaic = False
        if hasattr(self.train_loader.dataset, 'close_mosaic'):
            self.train_loader.dataset.close_mosaic(hyp=self.args)
我的是last.pt复制到D:\learn\sdxx\mbjc\ultralytics\weights\last.pt
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。