当前位置:   article > 正文

torch.save torch.load 四种使用方式 如何加载模型 如何加载模型参数 如何保存模型 如何保存模型参数

torch.load

在 PyTorch 中,我们可以使用 torch.save 函数将模型或张量保存到文件中,使用 torch.load 函数从文件中加载模型或张量。具体用法如下:

保存模型

	import torch
    # 定义模型
    model = ...
    # 保存模型
    torch.save(model.state_dict(), 'model.pth')
  • 1
  • 2
  • 3
  • 4
  • 5

在上面的代码中,我们使用 model.state_dict() 函数将模型的参数保存为一个字典,并使用 torch.save 函数将字典保存到名为 'model.pth' 的文件中。如果需要保存整个模型,可以使用 torch.save(model, 'model.pth') 函数保存模型。

加载模型

	import torch
    # 定义模型
    model = ...
    # 加载模型
    model.load_state_dict(torch.load('model.pth'))
  • 1
  • 2
  • 3
  • 4
  • 5

在上面的代码中,我们使用 torch.load 函数从名为 'model.pth' 的文件中加载模型的参数字典,并使用 model.load_state_dict 函数加载参数字典到模型中。如果需要加载整个模型,可以使用 model = torch.load('model.pth') 函数加载模型。

保存张量

	import torch
    # 定义张量
    tensor = ...
    # 保存张量
    torch.save(tensor, 'tensor.pth')
  • 1
  • 2
  • 3
  • 4
  • 5

在上面的代码中,我们使用 torch.save 函数将张量保存到名为 'tensor.pth' 的文件中。

加载张量

	import torch
    # 加载张量
    tensor = torch.load('tensor.pth')
  • 1
  • 2
  • 3

在上面的代码中,我们使用 torch.load 函数从名为 'tensor.pth' 的文件中加载张量。

如果使用 torch.save(model) 函数保存整个模型,可以使用 torch.load 函数直接加载整个模型。具体用法如下:

保存模型

	import torch
    # 定义模型
    model = ...
    # 保存模型
    torch.save(model, 'model.pth')
  • 1
  • 2
  • 3
  • 4
  • 5

在上面的代码中,我们使用 torch.save 函数将整个模型保存到名为 'model.pth' 的文件中。

加载模型

	import torch
    # 加载模型
    model = torch.load('model.pth')
  • 1
  • 2
  • 3

在上面的代码中,我们使用 torch.load 函数从名为 'model.pth' 的文件中加载整个模型。需要注意的是,如果模型是在 GPU 上训练的,加载模型时需要使用 map_location 参数将模型映射到 CPU 上:

	import torch
    # 加载模型
    model = torch.load('model.pth', map_location=torch.device('cpu'))
  • 1
  • 2
  • 3

如果模型是在 GPU 上训练的,而且需要将模型加载到指定的 GPU 上,可以使用 torch.cuda.device 函数切换到指定的 GPU,然后将模型加载到该 GPU 上:

	import torch
    # 切换到指定的 GPU
    torch.cuda.device(1)
    # 加载模型
    model = torch.load('model.pth', map_location=torch.device('cuda:1'))
  • 1
  • 2
  • 3
  • 4
  • 5

在上面的代码中,我们使用 torch.cuda.device 函数切换到索引为 1 的 GPU,然后将模型加载到该 GPU 上。

如果使用 torch.save(model) 函数保存模型,加载模型时可以使用 model.load_state_dict 函数只加载模型的参数。具体用法如下:

保存模型

	import torch
    # 定义模型
    model = ...
    # 保存模型
    torch.save(model, 'model.pth')
  • 1
  • 2
  • 3
  • 4
  • 5

在上面的代码中,我们使用 torch.save 函数将整个模型保存到名为 'model.pth' 的文件中。

加载模型参数

	import torch
    # 定义模型
    model = ...
    # 加载模型参数
    model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
  • 1
  • 2
  • 3
  • 4
  • 5

在上面的代码中,我们使用 torch.load 函数从名为 'model.pth' 的文件中加载整个模型,并使用 model.load_state_dict 函数将加载的参数字典加载到模型中。需要注意的是,如果模型是在 GPU 上训练的,加载模型时需要使用 map_location 参数将模型映射到 CPU 上。如果模型在 GPU 上训练并且需要加载到指定的 GPU 上,请参考前面的回答。

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号