赞
踩
目录
2.加载上一步读取的数据 load_state_dict()
torch.save(parameters, addr)
parameters: 是待保存的权重参数,这个可以是网络的权重参数,也可以是包含多类数据的dict;
addr: 是存放数据的地址,相对地址,包括文件全名;如:addr = 'save/model.h5'
名称.state_dict() 函数——该函数用于获取 网络模型 或 优化器 的权重参数。
示例:
- # policy_net 是训练好、待保存的网络
- torch.save(policy_net.state_dict(), 'D:/project/model.h5')
以dict变量的形式同时保存多项内容,具体步骤如下:
(1)定义dict变量
dict变量中可以包含多种数据,例如:网络权重、优化器权重、学习率、epsilon、以及其他自定义的参数。
(2)调用torch.save()保存变量
torch.save(已定义的dict变量,拟存放地址)
完。
示例:
- model = {
- 'net':policy_net.state_dict(), # 网络权重
- 'opt':optimizer.state_dict(), # 优化器参数
- 'eps':epsilon, # 当前epsilon-自定义
- 'tsp':total_step, # 当前运行的步数-自定义
- }
- savedir = 'History/mode_save.h5' # 设定保存模型的地址及文件名
- torch.save(model,savedir) # 执行保存
checkpoints = torch.load(addr)
其中:
addr 是本文模型的存放地址,包括完整的文件名;
checkpoints 是读取出的数据;
读取出的数据就是上一步定义的checkpoints。
(1)当checkpoints 中只包含网络权重时,可以直接加载,如下:
- savedir = 'History/mode_save.h5' # 设定保存模型的地址及文件名
- checkpoints = torch.load(savedir) # 从本地读取
- policy_net.load_state_dict(checkpoint) # 加载
(2)当checkpoints 中包含多项数据时,根据dict变量的规则,取出数据:
- savedir = 'History/mode_save.h5' # 设定保存模型的地址及文件名
- checkpoints = torch.load(savedir) # 从本地读取
-
- policy_net.load_state_dict(checkpoint['net']) # 加载网络权重
- optimizer.load_state_dict(checkpoint['opt']) # 加载优化器权重
- epsilon = checkpoint['eps'] # 取出epsilon值
- total_step= checkpoint['tsp'] # 取出total_step值
optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
params: 是待优化的参数,一般为网络的权重;
lr: 是学习率,可不指定,默认值是0.001;
betas: 用于计算梯度以及梯度平方的运行平均值的系数,可不指定,默认值(0.9, 0.999);
eps: 为了增加数值计算的稳定性而加到分母里的项,可不指定,默认值1e-08;
weight_decay: 权重衰减L2惩罚,可不指定,默认值 0;
2.优化器建立后修改学习率
优化器是一个列表,学习率存放在{list:0}内,{list:0}是一个dict变量,学习率存放在'lr'内;
示例:
- # optimizer是定义好的优化器
- optimizer.param_groups[0]['lr'] = 0.999
-
- # 说明:
- # optimizer.param_groups是list,优化器的参数全在该项中;
- # optimizer.param_groups[0]是list的第一项,该项是dict变量;
- # ['lr']是dict变量中存放学习率的编号;
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。