当前位置:   article > 正文

Pytorch加载模型权重理解(state_dict load_state_dict update load)_unet.load_state_dict 函数作用

unet.load_state_dict 函数作用

Pytorch加载模型权重理解(state_dict load_state_dict update load)

一、state_dict特性介绍

在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm’s running_mean,关于batchnorm详解可见https://blog.csdn.net/wzy_zju/article/details/81262453

torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。

因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。

二、torch.load函数特性介绍

作用:用来加载torch.save() 保存的模型文件。

使用方式

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args)
  • 1

参数解释:

f:权重文件地址

map_location:设备 CPU还是GPU

后两个参数可以不用管

三、update函数使用方法

d1.update(d2)的作用是,将字典d2的内容合并到d1中,

其中d2中的键值对但d1中没有的键值对会增加到d1中去,

两者都有的键值对更新为d2的键值对.

d1 = {"浙江":"杭州","江苏":"nanjing"}
d1
{'浙江': '杭州', '江苏': 'nanjing'}
d1.update(江苏="南京")
d1
{'浙江': '杭州', '江苏': '南京'}
d2 = {"山东":"济南","河北":"石家庄"}
d1
{'浙江': '杭州', '江苏': '南京'}
d1.update(d2)
d1
{'浙江': '杭州', '江苏': '南京', '山东': '济南', '河北': '石家庄'}
d3 = {"浙江":"杭州市*****"}
d1
{'浙江': '杭州', '江苏': '南京', '山东': '济南', '河北': '石家庄'}
d1.update(d3)
d1
{'浙江': '杭州市*****', '江苏': '南京', '山东': '济南', '河北': '石家庄'}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

四、items函数的用法

Python 字典(Dictionary) items() 函数以列表返回可遍历的(键, 值) 元组数组。

items() 方法把字典中每对 key 和 value 组成一个元组,并把这些元组放在列表中返回

五、load_state_dict函数用法

与state_dict相比,我理解的是,load_state_dict是更新好的权重放回去,state_dict是将权重系数取出来。

六、综合分析

权重的预加载可以综合到这几步

if G_model_path != '':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_dict = G_model.state_dict()
        pretrained_dict = torch.load(G_model_path, map_location=device)
        pretrained_dict = {k : v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
        model_dict.update(pretrained_dict)
        G_model.load_state_dict(model_dict)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

第一步:获取当前设备到底是GPU还是CPU

第二步:取出当前还未加载权重的字典

第三步:用torch.load获取与训练好的新的权重的字典

第四步:在第三步的字典中,判断第三步新的权重和原始模型的权重的大小shape是否一致

​ 如果一致,新的权重字典就保留这个(键、值)权重

​ 如果不一致,新的权重字典就舍去这个(键、值)权重

第五步:用第四步最新的 字典 来更新第二步的原始字典

第六步:用第五步更新后的权重字典放回模型中。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/406953
推荐阅读
相关标签
  

闽ICP备14008679号