当前位置:   article > 正文

torch.load、torch.save、torch.optim.Adam的用法_torch.save用法

torch.save用法

目录

一、保存模型-torch.save()

1.只保存model的权重

2.保存多项内容

二、加载模型-torch.load()

1.从本地模型中读取数据

2.加载上一步读取的数据 load_state_dict()

三、torch.optim.Adam

1.optim.Adam()参数说明


一、保存模型-torch.save()

torch.save(parameters, addr)

parameters: 是待保存的权重参数,这个可以是网络的权重参数,也可以是包含多类数据的dict

addr: 是存放数据的地址,相对地址,包括文件全名;如:addr = 'save/model.h5'

1.只保存model的权重

名称.state_dict() 函数——该函数用于获取 网络模型优化器 的权重参数。

示例:

  1. # policy_net 是训练好、待保存的网络
  2. torch.save(policy_net.state_dict(), 'D:/project/model.h5')

2.保存多项内容

以dict变量的形式同时保存多项内容,具体步骤如下:

(1)定义dict变量

dict变量中可以包含多种数据,例如:网络权重、优化器权重、学习率、epsilon、以及其他自定义的参数

(2)调用torch.save()保存变量

torch.save(已定义的dict变量,拟存放地址)

完。

示例:

  1. model = {
  2. 'net':policy_net.state_dict(), # 网络权重
  3. 'opt':optimizer.state_dict(), # 优化器参数
  4. 'eps':epsilon, # 当前epsilon-自定义
  5. 'tsp':total_step, # 当前运行的步数-自定义
  6. }
  7. savedir = 'History/mode_save.h5' # 设定保存模型的地址及文件名
  8. torch.save(model,savedir) # 执行保存

二、加载模型-torch.load()

1.从本地模型中读取数据

checkpoints = torch.load(addr)

其中:

addr              是本文模型的存放地址,包括完整的文件名;

checkpoints 是读取出的数据;

2.加载上一步读取的数据 load_state_dict()

读取出的数据就是上一步定义的checkpoints。

(1)当checkpoints 中只包含网络权重时,可以直接加载,如下:

  1. savedir = 'History/mode_save.h5' # 设定保存模型的地址及文件名
  2. checkpoints = torch.load(savedir) # 从本地读取
  3. policy_net.load_state_dict(checkpoint) # 加载

(2)当checkpoints 中包含多项数据时,根据dict变量的规则,取出数据:

  1. savedir = 'History/mode_save.h5' # 设定保存模型的地址及文件名
  2. checkpoints = torch.load(savedir) # 从本地读取
  3. policy_net.load_state_dict(checkpoint['net']) # 加载网络权重
  4. optimizer.load_state_dict(checkpoint['opt']) # 加载优化器权重
  5. epsilon = checkpoint['eps'] # 取出epsilon值
  6. total_step= checkpoint['tsp'] # 取出total_step值

三、torch.optim.Adam

1.optim.Adam()参数说明

optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

params:            是待优化的参数,一般为网络的权重;

lr:                   是学习率,可不指定,默认值是0.001;

betas:            用于计算梯度以及梯度平方的运行平均值的系数,可不指定,默认值(0.9, 0.999);

eps:                  为了增加数值计算的稳定性而加到分母里的项,可不指定,默认值1e-08;

 weight_decay: 权重衰减L2惩罚,可不指定,默认值 0;

2.优化器建立后修改学习率

优化器是一个列表,学习率存放在{list:0}内,{list:0}是一个dict变量,学习率存放在'lr'内;

示例:

  1. # optimizer是定义好的优化器
  2. optimizer.param_groups[0]['lr'] = 0.999
  3. # 说明:
  4. # optimizer.param_groups是list,优化器的参数全在该项中;
  5. # optimizer.param_groups[0]是list的第一项,该项是dict变量;
  6. # ['lr']是dict变量中存放学习率的编号;

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/103429?site
推荐阅读
相关标签
  

闽ICP备14008679号