当前位置:   article > 正文

python神经网络模型怎么保存_Pytorch模型的保存和提取以及保存和提取神经网络 - pytorch中文网...

pytorch 保存指标最好的模型

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下pytorch序列化语义以及我们保存和提取回归的神经网络.中文文档地址为:序列化语义

基本的pytorch保存和加载模型

保存和提取主要使用torch.save和torch.load方法实现保存和提取

import torch

test_data = torch.FloatTensor(2,3)

# 保存数据

torch.save(test_data, "test_data.pkl")

print test_data

# 提取数据

print torch.load("test_data.pkl")

保存和提取神经网络

第一种:只保存和加载模型参数(推荐使用)

# 保存

torch.save(the_model.state_dict(), PATH)

# 提取

the_model = TheModelClass(*args, **kwargs)

the_model.load_state_dict(torch.load(PATH))

使用实例:

net = torch.nn.Sequential(

torch.nn.Linear(1, 10),

torch.nn.ReLU(),

torch.nn.Linear(10, 1)

)

torch.save(net.state_dict(), "net_params.pkl")

# 这种方式将会提取整个神经网络, 网络大的时候可能会比较慢.

print torch.load('net_pa

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/439350
推荐阅读
相关标签
  

闽ICP备14008679号