当前位置:   article > 正文

加载dict_源码详解Pytorch的state_dict和load_state_dict

destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._ver

Pytorch 中一种模型保存和加载的方式如下:

  1. # save
  2. torch.save(model.state_dict(), PATH)
  3. # load
  4. model = MyModel(*args, **kwargs)
  5. model.load_state_dict(torch.load(PATH))
  6. model.eval()

model.state_dict()其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数,下面看看源代码如何实现的。

state_dict

  1. # torch.nn.modules.module.py
  2. class Module(object):
  3. def state_dict(self, destination=None, prefix='', keep_vars=False):
  4. if destination is None:
  5. destination = OrderedDict()
  6. destination._metadata = OrderedDict()
  7. destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
  8. for name, param in self._parameters.items():
  9. if param is not None:
  10. destination[prefix + name] = param if keep_vars else param.data
  11. for name, buf in self._buffers.items():
  12. if buf is not None:
  13. destination[prefix + name] = buf if keep_vars else buf.data
  14. for name, module in self._modules.items():
  15. if module is not None:
  16. module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
  17. for hook in self._state_dict_hooks.values():
  18. hook_result = hook(self, destination, prefix, local_metadata)
  19. if hook_result is not None:
  20. destination = hook_result
  21. return destination

可以看到state_dict函数中遍历了4中元素,分别是_paramters,_buffers,

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/260396
推荐阅读
相关标签
  

闽ICP备14008679号