One way to save and load a model in PyTorch is the following:
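```python
import torch

# save: store only the parameters and buffers (the state_dict), not the whole model object
torch.save(model.state_dict(), PATH)

# load: rebuild the model first, then copy the saved tensors into it
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()  # switch dropout / batch-norm layers to inference mode
```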
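The save/load pattern above can be checked end to end with a small sketch. The `MyModel` class below is a hypothetical stand-in (a single `nn.Linear`), and an in-memory `io.BytesIO` buffer replaces the file path `PATH` so the example stays self-contained:

```python
import io

import torch
import torch.nn as nn


# Hypothetical stand-in for MyModel; the architecture is purely illustrative.
class MyModel(nn.Module):
    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.fc(x)


model = MyModel()

# Save only the state_dict, into an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Load: rebuild the architecture, then copy the saved tensors into it.
buffer.seek(0)
restored = MyModel()
restored.load_state_dict(torch.load(buffer))
restored.eval()

x = torch.randn(3, 4)
print(torch.equal(model(x), restored(x)))  # True: identical weights, identical output
```

Because only tensors are saved, loading requires the model class to be constructed first; this is the main difference from pickling the whole model object.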
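`model.state_dict()` actually returns an `OrderedDict` that maps the name of each parameter and buffer in the network to its tensor. Let's look at how the source code implements it.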
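A quick way to see what this `OrderedDict` holds is to print it for a small model. The `nn.Sequential` of a `Linear` and a `BatchNorm1d` layer below is an illustrative choice; the batch-norm layer is included because it contributes buffers (running statistics) as well as parameters:

```python
from collections import OrderedDict

import torch.nn as nn

# Illustrative model: child "0" is a Linear layer, child "1" is BatchNorm1d.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

sd = model.state_dict()
print(isinstance(sd, OrderedDict))  # True
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 1.weight (3,)
# 1.bias (3,)
# 1.running_mean (3,)
# 1.running_var (3,)
# 1.num_batches_tracked ()
```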
```python
# torch.nn.modules.module.py (excerpt)
from collections import OrderedDict


class Module(object):
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # top-level call: create the OrderedDict and its metadata store
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
        # 1) this module's own parameters, keyed as prefix + name
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param if keep_vars else param.data
        # 2) this module's own buffers (e.g. batch-norm running statistics)
        for name, buf in self._buffers.items():
            if buf is not None:
                destination[prefix + name] = buf if keep_vars else buf.data
        # 3) recurse into child modules, extending the prefix with "<child name>."
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
        # 4) registered hooks may post-process or replace the result
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination
```
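To make the recursion concrete, here is a simplified re-implementation, `simple_state_dict` (a hypothetical helper, not part of PyTorch). It walks `_parameters`, `_buffers` and `_modules` the same way as the excerpt above but, as an assumption to keep it short, skips `_metadata`, the hooks and `keep_vars`, and always detaches the tensors. It also does not exclude non-persistent buffers (`register_buffer(..., persistent=False)`), which newer PyTorch versions leave out of the real `state_dict`:

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def simple_state_dict(module, destination=None, prefix=""):
    """Hypothetical, simplified version of Module.state_dict (no hooks/metadata)."""
    if destination is None:
        destination = OrderedDict()
    # 1) own parameters
    for name, param in module._parameters.items():
        if param is not None:
            destination[prefix + name] = param.detach()
    # 2) own buffers
    for name, buf in module._buffers.items():
        if buf is not None:
            destination[prefix + name] = buf.detach()
    # 3) recurse into children with the extended prefix
    for name, child in module._modules.items():
        if child is not None:
            simple_state_dict(child, destination, prefix + name + ".")
    return destination


model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
mine = simple_state_dict(model)
ref = model.state_dict()
print(list(mine) == list(ref))                         # True: same keys, same order
print(all(torch.equal(mine[k], ref[k]) for k in ref))  # True: same values

# keep_vars=True stores the nn.Parameter objects themselves;
# the default stores plain tensors that do not require grad.
print(isinstance(model.state_dict(keep_vars=True)["0.weight"], nn.Parameter))  # True
print(ref["0.weight"].requires_grad)                                            # False
```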
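As you can see, `state_dict` iterates over four kinds of elements: `_parameters`, `_buffers`, `_modules` and `_state_dict_hooks`. Parameters and buffers are stored under the key `prefix + name`; each child module is visited recursively with `prefix + name + '.'` as its new prefix, which is what produces dotted keys such as `fc.weight` for the weight of a child named `fc`; finally, any registered hooks may modify or replace the resulting dictionary.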
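Which of these dictionaries an attribute lands in depends on how it is assigned. The hypothetical `Block` module below shows the three storage paths (a submodule, an `nn.Parameter`, and a tensor registered with `register_buffer`) plus an ordinary tensor attribute that `state_dict` does not see. The key order follows the traversal order of the source: own parameters, then own buffers, then child modules:

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    # Hypothetical module used only to show where each attribute is stored.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)                  # stored in _modules
        self.scale = nn.Parameter(torch.ones(2))       # stored in _parameters
        self.register_buffer("step", torch.tensor(0))  # stored in _buffers
        self.plain = torch.zeros(2)                    # ordinary attribute, not saved


block = Block()
print(list(block._parameters))   # ['scale']
print(list(block._buffers))      # ['step']
print(list(block._modules))      # ['linear']
print(list(block.state_dict()))  # ['scale', 'step', 'linear.weight', 'linear.bias']
```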