赞
踩
torch模型保存。
模型保存的本质就是利用pickle模块进行序列化。序列化到文件,从文件反序列化回来的对象,要么是Python自定义的对象,要么是本文件中已经定义的类。
import torch import torch.nn as nn import torch.optim as optim class Model(nn.Module): def __init__(self, input_size, output_size): super(Model, self).__init__() self.linear1 = nn.Linear(input_size, input_size * 2) self.linear2 = nn.Linear(input_size * 2, output_size) def forward(self, inputs): inputs = self.linear1(inputs) output = self.linear2(inputs) return output
第一种方式
model = Model()
torch.save(model,'./model.pth')
model = torch.load('./model.pth')
第二种方式
model = Model()
torch.save(model.state_dict(), './model_state_dict.pth')
model = Model()
model.load_state_dict('./model_state_dict.pth')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。