当前位置:   article > 正文

每天讲解一点PyTorch 【15】model.load_state_dict torch.load torch.save

每天讲解一点PyTorch 【15】model.load_state_dict torch.load torch.save

今天我们讲解:

state_dict = torch.load('checkpoint.pt')
#或者
state_dict = torch.load('checkpoint.pth') #torch.load加载**模型参数**
model.load_state_dict(state_dict) #把模型参数加载到模型中

model.cuda()
model.eval() #model.eval()关闭Batch Normalization和Dropout层
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
#加载模型结构和模型参数
model = torch.load(path)
output = model(x)
  • 1
  • 2
  • 3

torch.save(model.state_dict(), ‘checkpoint.pt’) #仅保存模型参数
torch.save(model,‘checkpoint.pt’) #保存模型结构和模型参数

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/785883
推荐阅读
相关标签
  

闽ICP备14008679号