当前位置:   article > 正文

torch模型保存_torch 保存模型

torch 保存模型

torch模型保存。

模型保存的本质就是利用pickle模块进行序列化。序列化到文件,从文件反序列化回来的对象,要么是Python自定义的对象,要么是本文件中已经定义的类。

import torch
import torch.nn as nn
import torch.optim as optim


class Model(nn.Module):

    def __init__(self, input_size, output_size):

        super(Model, self).__init__()
        self.linear1 = nn.Linear(input_size, input_size * 2)
        self.linear2 = nn.Linear(input_size * 2, output_size)

    def forward(self, inputs):

        inputs = self.linear1(inputs)
        output = self.linear2(inputs)
        return output
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

第一种方式

model = Model()
torch.save(model,'./model.pth')

model = torch.load('./model.pth')
  • 1
  • 2
  • 3
  • 4

第二种方式

model = Model()
torch.save(model.state_dict(), './model_state_dict.pth')

model = Model()
model.load_state_dict('./model_state_dict.pth')
  • 1
  • 2
  • 3
  • 4
  • 5
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小惠珠哦/article/detail/772757
推荐阅读
相关标签
  

闽ICP备14008679号