赞
踩
当模型在训练过程中时,随着其不断接收更多数据,其性能也会发生变化。在训练过程中保存模型的状态是一种最佳实践。这样可以在开发模型的过程中,在每个关键点上获得模型的一个版本,即一个检查点。一旦训练完成,您可以使用在训练过程中找到的性能最佳的检查点。
检查点还使得训练在中断的情况下可以从中断的地方恢复。
PyTorch Lightning 检查点在普通的 PyTorch 中完全可用。
一个 Lightning 检查点包含了模型的整个内部状态的转储。与普通的 PyTorch 不同,Lightning 保存了你在最复杂的分布式训练环境中恢复模型所需的一切。
在 Lightning 检查点中,您会找到:
nn.Module 的模型权重,具体使用方法如下。
Lightning checkpoints 完全兼容普通的 torch nn.Modules。
checkpoint = torch.load(CKPT_PATH)
print(checkpoint.keys())
例如,假设像下面这样创建了一个 LightningModule:
class Encoder(nn.Module): ... class Decoder(nn.Module): ... class Autoencoder(L.LightningModule): def __init__(self, encoder, decoder, *args, **kwargs): super().__init__() self.encoder = encoder self.decoder = decoder autoencoder = Autoencoder(Encoder(), Decoder())
一旦autoencoder训练完成,就可以提取出与 torch nn.Module 相关的权重。
checkpoint = torch.load(CKPT_PATH)
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}
官方文档:https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。