
YOLOv8: How to Continue Training an Already-Trained Model with Additional Epochs

1. First, in the launch script train.py, set resume=True; the other settings can be whatever you like.

    from ultralytics import YOLO

    model = YOLO('runs/detect/train/weights/last.pt')
    results = model.train(save=True, resume=True)
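
Before resuming, it can help to inspect what last.pt actually stores, since every edit below hinges on its contents. A minimal sketch, assuming the standard ultralytics checkpoint layout (keys such as 'epoch' and 'train_args'); note that a fully finished run typically stores epoch as -1, which is why resume=True alone will not add epochs to a completed run:

    import torch

    ckpt = torch.load('runs/detect/train/weights/last.pt', map_location='cpu')
    print(ckpt.get('epoch'))       # last completed epoch (typically -1 once training finished)
    print(ckpt.get('train_args'))  # the configuration the original run was launched with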

2. Modify ultralytics/engine/trainer.py

In ultralytics/engine/trainer.py, find def check_resume(self) and replace resume = self.args.resume with resume = 'runs/detect/train/weights/last.pt'; this path is the checkpoint you want to continue training from.

    def check_resume(self):
        ###### modified ###############
        # resume = self.args.resume
        resume = 'runs/detect/train/weights/last.pt'
        ###############################

Then find def resume_training(self, ckpt): and add the following as its first line (torch is already imported at the top of trainer.py):

    ckpt = torch.load('runs/detect/train/weights/last.pt')

Then change start_epoch = ckpt['epoch'] + 1 to the epoch count of the previous run. For example, if you previously trained for 100 epochs and now want to continue for another 50, set it to 100.

    def resume_training(self, ckpt):
        """Resume YOLO training from given epoch and best fitness."""
        ###### modified ###############
        ckpt = torch.load('runs/detect/train/weights/last.pt')
        ###############################
        if ckpt is None:
            return
        best_fitness = 0.0
        ###### modified ###############
        # start_epoch = ckpt['epoch'] + 1
        start_epoch = 100
        ###############################
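
If you prefer not to hard-code the number, a hedged alternative is to derive it from the checkpoint itself; this assumes the checkpoint's 'train_args' entry stores the epochs value the previous run was launched with:

    start_epoch = ckpt['train_args']['epochs']  # e.g. 100 if the first run used epochs=100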

Next, find BaseTrainer's

def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):

and change self.epochs = self.args.epochs to the total number of epochs for this run. For example, if you previously trained for 100 epochs and want to continue for 50 more, set it to 150 (the previous 100 plus the new 50).

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.check_resume(overrides)
        self.device = select_device(self.args.device, self.args.batch)
        self.validator = None
        self.model = None
        self.metrics = None
        self.plots = {}
        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

        # Dirs
        self.save_dir = get_save_dir(self.args)
        self.wdir = self.save_dir / 'weights'  # weights dir
        if RANK in (-1, 0):
            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
            self.args.save_dir = str(self.save_dir)
            yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args
        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
        self.save_period = self.args.save_period
        self.batch_size = self.args.batch
        ###### modified ###############
        # self.epochs = self.args.epochs
        self.epochs = 150
        ###############################
        self.start_epoch = 0
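
Why must self.epochs be the cumulative total (150) rather than just the additional 50? The learning-rate schedule built in _setup_train (shown in step 4 below) is computed against self.epochs, so resuming at epoch 100 with the total set to 150 continues the decay smoothly instead of restarting it. A small sketch of the trainer's own linear schedule, with illustrative values:

    lrf = 0.01     # final LR fraction (the default self.args.lrf)
    epochs = 150   # the cumulative total set above
    lf = lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf  # same formula as in _setup_train
    print(lf(100)) # LR multiplier at the resume epoch, approximately 0.34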

3. Modify ultralytics/engine/model.py

In model.py, find def train(self, trainer=None, **kwargs): and modify it as follows, so that the model already loaded from last.pt is handed to the trainer directly instead of being rebuilt from its YAML config:

    def train(self, trainer=None, **kwargs):
        """
        Trains the model on a given dataset.

        Args:
            trainer (BaseTrainer, optional): Customized trainer.
            **kwargs (Any): Any number of arguments representing the training configuration.
        """
        self._check_is_pytorch_model()
        if self.session:  # Ultralytics HUB session
            if any(kwargs):
                LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
            kwargs = self.session.train_args
        check_pip_update_available()
        overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
        custom = {'data': TASK2DATA[self.task]}  # method defaults
        args = {**overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right
        if args.get('resume'):
            args['resume'] = self.ckpt_path
        self.trainer = (trainer or self.smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
        if not args.get('resume'):  # manually set model only if not resuming
            ###### modified ###############
            # self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            # self.model = self.trainer.model
            self.trainer.model = self.model
            ###############################
        self.trainer.hub_session = self.session  # attach optional HUB session
        self.trainer.train()
        # Update model and cfg after training
        if RANK in (-1, 0):
            ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
            self.model, _ = attempt_load_one_weight(ckpt)
            self.overrides = self.model.args
            self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP
        return self.metrics

Finally, just run train.py.

Important: revert these code changes once training completes!

4. If a previous run stopped early because patience was too small and you want to keep training up to a given epoch count, simply raise patience to a large value; here I set it to 300.

In ultralytics/engine/trainer.py:

    def _setup_train(self, world_size):
        """Builds dataloaders and optimizer on correct rank process."""
        # Model
        self.run_callbacks('on_pretrain_routine_start')
        ckpt = self.setup_model()
        self.model = self.model.to(self.device)
        self.set_model_attributes()

        # Freeze layers
        freeze_list = self.args.freeze if isinstance(
            self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else []
        always_freeze_names = ['.dfl']  # always freeze these layers
        freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
        for k, v in self.model.named_parameters():
            # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
            if any(x in k for x in freeze_layer_names):
                LOGGER.info(f"Freezing layer '{k}'")
                v.requires_grad = False
            elif not v.requires_grad:
                LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
                            'See ultralytics.engine.trainer for customization of frozen layers.')
                v.requires_grad = True

        # Check AMP
        self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
        if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
            callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
            self.amp = torch.tensor(check_amp(self.model), device=self.device)
            callbacks.default_callbacks = callbacks_backup  # restore callbacks
        if RANK > -1 and world_size > 1:  # DDP
            dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
        self.amp = bool(self.amp)  # as boolean
        self.scaler = amp.GradScaler(enabled=self.amp)
        if world_size > 1:
            self.model = DDP(self.model, device_ids=[RANK])

        # Check imgsz
        gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)

        # Batch size
        if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
            self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)

        # Dataloaders
        batch_size = self.batch_size // max(world_size, 1)
        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
        if RANK in (-1, 0):
            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
            self.validator = self.get_validator()
            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
            self.ema = ModelEMA(self.model)
            if self.args.plots:
                self.plot_training_labels()

        # Optimizer
        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
        self.optimizer = self.build_optimizer(model=self.model,
                                              name=self.args.optimizer,
                                              lr=self.args.lr0,
                                              momentum=self.args.momentum,
                                              decay=weight_decay,
                                              iterations=iterations)

        # Scheduler
        if self.args.cos_lr:
            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
        else:
            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
        ###### modified ###############
        # self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
        self.stopper, self.stop = EarlyStopping(patience=300), False
        ###############################
        self.resume_training(ckpt)
        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
        self.run_callbacks('on_pretrain_routine_end')
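
For context on why a large patience works, here is a minimal sketch of the early-stopping test; this is an approximation (the real ultralytics EarlyStopping is more elaborate), but the core condition is the same: training stops once patience epochs pass without the fitness improving, so patience=300 effectively disables early stopping for a 150-epoch run.

    class SketchEarlyStopping:
        """Approximation of the stopper's core logic, for illustration only."""

        def __init__(self, patience=50):
            self.best_fitness = 0.0  # best fitness seen so far
            self.best_epoch = 0      # epoch at which it was seen
            self.patience = patience

        def __call__(self, epoch, fitness):
            if fitness >= self.best_fitness:  # improvement (ties included)
                self.best_epoch, self.best_fitness = epoch, fitness
            return (epoch - self.best_epoch) >= self.patience  # True -> stop training

    stopper = SketchEarlyStopping(patience=300)  # large patience ≈ never stops early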

Finally, just run train.py again.

Important: revert these code changes once training completes!
