There are two cases when using a pretrained model: models trained with PyTorch (which includes Lightning models) and third-party models. For third-party models, see their own documentation.
A LightningModule is also a subclass of nn.Module, so it can be used directly as an ordinary PyTorch module.
Call load_from_checkpoint(), a method implemented on LightningModule:
import torch
from torch import nn
from pytorch_lightning import LightningModule


class Encoder(torch.nn.Module):
    ...


class Decoder(torch.nn.Module):
    ...


class AutoEncoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()


class CIFAR10Classifier(LightningModule):
    def __init__(self):
        super().__init__()
        # init the pretrained LightningModule:
        # load pretrained weights with load_from_checkpoint
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()
        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...
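Because a LightningModule is just an nn.Module, its weights can also be pulled out of the checkpoint file by hand and loaded into a plain PyTorch module. A minimal sketch, assuming the checkpoint stores the weights under the usual "state_dict" key with submodule prefixes such as "encoder." (PATH and Encoder are the placeholders from the example above):

import torch

# a Lightning checkpoint is an ordinary dict; model weights live under "state_dict"
checkpoint = torch.load(PATH, map_location="cpu")
# keep only the encoder weights and strip the "encoder." prefix
encoder_weights = {
    k.replace("encoder.", "", 1): v
    for k, v in checkpoint["state_dict"].items()
    if k.startswith("encoder.")
}
encoder = Encoder()
encoder.load_state_dict(encoder_weights)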
During training, multiple checkpoints can be saved manually or automatically (e.g. last.ckpt, best.ckpt) and used later for testing, resuming training, and so on. Each checkpoint contains the following (a sketch for inspecting one follows this list):
Global step (the number of training_step calls run so far)
LightningModule's state_dict (the model weights)
State of all optimizers
State of all learning rate schedulers
State of all callbacks (for stateful callbacks)
State of datamodule (for stateful datamodules)
The hyperparameters (init arguments) with which the model was created
The hyperparameters (init arguments) with which the datamodule was created
State of Loops
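To see this for yourself, a checkpoint file can be opened with torch.load and inspected as a plain dictionary. A minimal sketch; the exact top-level key names (e.g. "optimizer_states", "lr_schedulers", "hyper_parameters") may vary across PyTorch Lightning versions:

import torch

checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location="cpu")
# the top-level keys cover the items listed above, e.g.
# dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict',
#            'loops', 'callbacks', 'optimizer_states', 'lr_schedulers',
#            'hparams_name', 'hyper_parameters'])
print(checkpoint.keys())
print(checkpoint["epoch"], checkpoint["global_step"])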
By default the Trainer saves checkpoints automatically. With a checkpoint callback attached, saving follows the callback's conditions instead, e.g. keep the checkpoint with the best value of a metric logged via self.log (such as val/acc or val/f1), or save automatically every n epochs or steps (see the ModelCheckpoint sketch below).
# save into the given directory at the end of every epoch
trainer = Trainer(default_root_dir="some/path/")
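For finer control than default_root_dir, a ModelCheckpoint callback can be passed to the Trainer. A sketch assuming a metric named "val/acc" is logged via self.log; the arguments shown (monitor, mode, save_top_k, save_last, every_n_epochs) are part of the standard ModelCheckpoint API:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="some/path/",
    filename="best-{epoch}-{step}",
    monitor="val/acc",  # metric logged with self.log("val/acc", ...)
    mode="max",         # keep the checkpoint with the highest val/acc
    save_top_k=1,       # how many best checkpoints to retain
    save_last=True,     # also keep an up-to-date last.ckpt
    every_n_epochs=1,   # check at the end of every epoch
)
trainer = Trainer(callbacks=[checkpoint_callback])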
To load a saved model for inference, call the load_from_checkpoint method:
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
# disable randomness, dropout, etc...
model.eval()
# predict with the model
y_hat = model(x)
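When the checkpoint was written on a different device, load_from_checkpoint accepts a map_location argument (forwarded to torch.load), and inference is usually wrapped in torch.no_grad(). A short sketch; x is assumed to be an input batch of the right shape:

import torch

model = MyLightningModule.load_from_checkpoint(
    "/path/to/checkpoint.ckpt", map_location="cpu"
)
model.eval()
with torch.no_grad():  # no gradients needed for pure inference
    y_hat = model(x)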
To resume training, set ckpt_path to the checkpoint file path when calling fit():
model = LitModel()
trainer = Trainer()
# automatically restores model, epoch, step, LR schedulers, etc...
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
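When a ModelCheckpoint callback tracked a best checkpoint during fit(), Trainer.test and Trainer.validate also accept ckpt_path="best" to pick it automatically (recent PyTorch Lightning versions additionally accept "last"). A short sketch:

# after trainer.fit(...) with a ModelCheckpoint callback attached:
trainer.test(model, ckpt_path="best")  # loads the best checkpoint it tracked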