当前位置:   article > 正文

YOLOv8 进行 resume 接续训练（yolov8 resume）

yolov8 resume

01 修改train脚本

        修改train中模型的加载方式

路径在ultralytics/yolo/engine/model.py中

01找到

 def train(self, **kwargs):

02 注释掉下面的代码（在下方的完整代码中，原有的 `if not overrides.get('resume'):` 模型初始化分支也一并被注释掉了）

self.trainer.model = self.model

03 修改后的完整代码如下

  # region def train - model training (modified copy from ultralytics/yolo/engine/model.py)
def train(self, **kwargs):
    """
    Run training for this model, modified so the trainer is not handed ``self.model``.

    Changes relative to the copied method are wrapped in ``# region modified``
    comments: the ``cfg`` YAML is loaded with ``append_filename=False``, and the
    lines that attach ``self.model`` to the trainer are commented out, so the
    trainer presumably builds/loads its own model from its overrides (that code
    is not shown here - confirm in the trainer).

    Args:
        **kwargs (Any): Training configuration overrides. When an Ultralytics HUB
            session is attached they are ignored (with a warning) and the
            session's ``train_args`` are used instead.

    Raises:
        AttributeError: If the final overrides contain no ``data`` entry.
    """
    self._check_is_pytorch_model()

    hub = self.session  # Ultralytics HUB session, if any
    if hub:
        if any(kwargs):
            LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
        kwargs = hub.train_args
    check_pip_update_available()

    # Start from the model's stored overrides, then layer the call's kwargs on top.
    overrides = self.overrides.copy()
    overrides.update(kwargs)
    cfg_file = kwargs.get('cfg')
    if cfg_file:
        LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
        # region modified: append_filename changed to False
        # (a cfg file replaces every other override collected above)
        overrides = yaml_load(check_yaml(cfg_file), append_filename=False)
        # endregion
    overrides['mode'] = 'train'
    if not overrides.get('data'):
        raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
    if overrides.get('resume'):
        # Any truthy `resume` is replaced by this model's checkpoint path.
        overrides['resume'] = self.ckpt_path

    self.task = overrides.get('task') or self.task
    trainer_cls = TASK_MAP[self.task][1]
    self.trainer = trainer_cls(overrides=overrides, _callbacks=self.callbacks)
    # region modified: do not re-initialise the model / hand self.model to the trainer
    # Disabled model (re)initialisation:
    # if not overrides.get('resume'):  # manually set model only if not resuming
    #     self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
    #     self.model = self.trainer.model
    # Added by the author, then also disabled:
    # self.trainer.model = self.model
    # endregion
    self.trainer.hub_session = self.session  # attach optional HUB session

    # region start training
    self.trainer.train()
    # After training, reload the best weights and refresh args/metrics; the
    # RANK check presumably limits this to the single/primary process - confirm.
    if RANK in (-1, 0):
        self.model, _ = attempt_load_one_weight(str(self.trainer.best))
        self.overrides = self.model.args
        self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP
    # endregion
# endregion

02 修改check_resume代码

        路径在ultralytics/yolo/engine/trainer.py

01找到check_resume方法

def check_resume(self):

02 注释掉下面代码

resume = self.args.resume

03 替换为直接指定权重路径（改成你自己 last.pt 所在的路径）

resume = r'D:\learn\sdxx\mbjc\ultralytics\weights\last.pt'

04 修改后的完整代码如下

  # Modified check_resume from ultralytics/yolo/engine/trainer.py
def check_resume(self):
    """
    Resolve the checkpoint to resume from and reload its training arguments.

    Modified: ``self.args.resume`` is no longer consulted; the checkpoint path
    is hard-coded below and must be edited to point at your own ``last.pt``.
    If that file does not exist, ``get_latest_run()`` is used instead, without
    any warning.

    On success ``self.args`` is rebuilt from the checkpoint's saved args,
    ``self.args.model`` is set to the checkpoint path, and ``self.resume`` is
    set to True.

    Raises:
        FileNotFoundError: If locating or loading the checkpoint fails for any
            reason (the original exception is chained).
    """
    # region modified: ignore self.args.resume and use a fixed checkpoint path
    # resume = self.args.resume
    resume = r'D:\learn\sdxx\mbjc\ultralytics\weights\last.pt'
    # endregion
    if not resume:
        self.resume = resume
        return
    try:
        on_disk = isinstance(resume, (str, Path)) and Path(resume).exists()
        last = Path(check_file(resume) if on_disk else get_latest_run())
        ckpt_args = attempt_load_weights(last).args
        # The data YAML recorded in the checkpoint may no longer exist; if so,
        # fall back to the data path from the current args.
        if not Path(ckpt_args['data']).exists():
            ckpt_args['data'] = self.args.data
        self.args = get_cfg(ckpt_args)
        self.args.model = str(last)  # reinstate
        resume = True
    except Exception as e:
        raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
                                "i.e. 'yolo train resume model=path/to/last.pt'") from e
    self.resume = resume

 03 修改def resume_training代码

          路径在ultralytics/yolo/engine/trainer.py

01找到

def resume_training(self, ckpt):

在函数体开头加入下面这一行（Windows 路径必须加 r 前缀写成原始字符串，否则 `\u` 等会被当作转义字符而报错）：

ckpt = torch.load(r'D:\learn\sdxx\mbjc\ultralytics\weights\last.pt')

修改后的完整代码如下

  # Modified resume_training from ultralytics/yolo/engine/trainer.py
def resume_training(self, ckpt):
    """
    Resume YOLO training from given epoch and best fitness.

    Modified: the ``ckpt`` argument is ignored and the checkpoint is always
    re-loaded from the hard-coded path below. Keep that path identical to the
    one used in ``check_resume``.

    From the checkpoint this restores optimizer state and ``best_fitness``
    (when ``ckpt['optimizer']`` is not None) and the EMA weights and update
    count (when EMA is enabled and the checkpoint has one). It then sets
    ``self.best_fitness`` and ``self.start_epoch``. If ``self.epochs`` is
    smaller than the start epoch, ``ckpt['epoch']`` is added to ``self.epochs``
    (treated as fine-tuning). Mosaic augmentation is switched off when the
    start epoch falls within the last ``self.args.close_mosaic`` epochs.

    Args:
        ckpt (dict | None): Checkpoint supplied by the caller; overwritten here.
    """
    # region modified: always load the fixed checkpoint
    # map_location='cpu' lets a checkpoint saved on a GPU load on any machine
    # (without it, loading CUDA tensors on a CPU-only machine raises).
    # NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True,
    # which may reject full Ultralytics checkpoints (they contain pickled
    # modules); pass weights_only=False there, and only for trusted files.
    ckpt = torch.load(r'D:\learn\sdxx\mbjc\ultralytics\weights\last.pt', map_location='cpu')
    # endregion
    if ckpt is None:  # kept from the original; normally not reached after the forced load
        return
    best_fitness = 0.0
    start_epoch = ckpt['epoch'] + 1
    if ckpt['optimizer'] is not None:
        self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
        best_fitness = ckpt['best_fitness']
    if self.ema and ckpt.get('ema'):
        self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
        self.ema.updates = ckpt['updates']
    if self.resume:
        assert start_epoch > 0, \
            f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
        LOGGER.info(
            f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
    if self.epochs < start_epoch:
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
        self.epochs += ckpt['epoch']  # finetune additional epochs
    self.best_fitness = best_fitness
    self.start_epoch = start_epoch
    # Resuming inside the final `close_mosaic` epochs: disable mosaic now.
    if start_epoch > (self.epochs - self.args.close_mosaic):
        LOGGER.info('Closing dataloader mosaic')
        if hasattr(self.train_loader.dataset, 'mosaic'):
            self.train_loader.dataset.mosaic = False
        if hasattr(self.train_loader.dataset, 'close_mosaic'):
            self.train_loader.dataset.close_mosaic(hyp=self.args)

04 复制最后的权重到相应路径

我的做法是把要接续训练的 last.pt 复制到 D:\learn\sdxx\mbjc\ultralytics\weights\last.pt。这个路径必须与上面 check_resume 和 resume_training 中写死的路径完全一致；换成你自己的路径时，两处都要修改。

05 重新启动训练，即可从 last.pt 中保存的进度继续训练

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/312830
推荐阅读
相关标签
  

闽ICP备14008679号