当前位置:   article > 正文

[PyTorch][PyTorch Lightning]:断点续训

[PyTorch][PyTorch Lightning]:断点续训

[PyTorch Lightning]:断点续训

要在 PyTorch Lightning 中从断点继续训练,可以使用以下步骤:

1. 保存断点

在训练过程中使用 ModelCheckpoint 回调来保存模型的状态。可以在 Trainer 中设置 checkpoint_callback 参数来使用该回调。

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filepath='model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    verbose=True,
    monitor='val_loss',
    mode='min'
)

trainer = Trainer(
    checkpoint_callback=checkpoint_callback,
    ...
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

ModelCheckpoint 回调可以保存以下量:

  1. 模型权重:保存模型的权重参数。

  2. Trainer 对象的状态:保存 Trainer 对象的状态,包括当前的 epoch 和 batch 等信息。

  3. 优化器状态:保存优化器的状态,包括当前的学习率和动量等信息。

  4. 日志:保存训练过程中的日志信息,包括损失、指标和训练时间等信息。

可以通过 ModelCheckpointsave_weights_onlysave_last 参数来控制是否保存模型的权重、最后一次保存的模型权重和 Trainer 对象的状态以及其他信息。例如,如果 save_weights_only=True,则只会保存模型的权重,而不会保存其他信息。

2. 断点恢复,重训

方法1: 使用 resume_from_checkpoint=‘’

在需要从断点继续训练的时候,创建一个新的 Trainer 对象,并使用 resume_from_checkpoint 参数来指定要继续训练的模型状态的路径。

trainer = Trainer(
    resume_from_checkpoint='path/to/checkpoint.ckpt',
    ...
)
  • 1
  • 2
  • 3
  • 4

使用新的 Trainer 对象重新运行 fit 函数即可从断点处继续训练模型。

trainer.fit(model, train_dataloader, val_dataloader)
  • 1

方法2: 使用 trainer.fit(ckpt_path=‘’)

由于在 PyTorch Lightning 版本 1.5 中,resume_from_checkpoint 参数已被弃用,并且在版本 2.0 中将被删除。新的替代方法是在 Trainerfit 方法中使用 ckpt_path 参数来指定要恢复训练的检查点路径。

例如,假设你有一个名为 model.ckpt 的检查点文件,你可以使用以下代码从该检查点处继续训练模型:

from pytorch_lightning import Trainer

trainer = Trainer(resume_from_checkpoint='model.ckpt')
trainer.fit(model, train_dataloader)
  • 1
  • 2
  • 3
  • 4

现在应该使用以下代码:

from pytorch_lightning import Trainer

trainer = Trainer(ckpt_path='model.ckpt')
trainer.fit(model, train_dataloader)
  • 1
  • 2
  • 3
  • 4

这将在 Trainer 中设置 ckpt_path 参数,以指示从该检查点处恢复训练。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/589915
推荐阅读
相关标签
  

闽ICP备14008679号