赞
踩
mxnet中 gluon.Trainer()是注册优化器的函数
trainer=gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':lr})
在with autograd.record():之后结合trainer.step(batch_size)更新网络权值
gluon.Trainer()的参数为:
params:net.collect_params(),一个类似于字典的类型
optimizer:字符串(优化器名称)
optimiter_parameter:字典类型(用于设置优化器参数{‘learning_rate’:lr,‘wd’:wd}
1、Trainer中首相将params装换为list类型->_params,之后用于trainer.step(batch_size)->调用内部函数_update()更新_params的权值
2、初始化优化器_init_optimizer(optimizer, optimizer_params),将_params转换为{i:param}的字典param_dict,然后调用
- import optimizer as opt
- self._optimizer = opt.create(optimizer, param_dict = param_dict, **optimizer_params)
3、create调用create_optimizer(name, **kwarg)返回与name相对应的优化器
4、在import optimizer as opt时其实已经运行了register函数,将所有的优化器类注册进了opt_registry字典,优化器类都被@register修饰,import时使用register函数进行了注册
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。