
PyTorch - state_dict() and parameters() Explained


Contents

1 parameters()

1.1 model.parameters():

1.2 model.named_parameters():

2 state_dict()


All learnable parameters of a torch.nn.Module are contained in the model's parameters and can be retrieved with model.parameters();

state_dict() is a dictionary holding all of the model's parameters (as tensors); it is mostly used for saving the model;

1 parameters()

1.1 model.parameters():

Source code:

def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    r"""Returns an iterator over module parameters.

    This is typically passed to an optimizer.

    Args:
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        Parameter: module parameter

    Example::

        >>> for param in model.parameters():
        >>>     print(type(param), param.size())
        <class 'torch.Tensor'> (20L,)
        <class 'torch.Tensor'> (20L, 1L, 5L, 5L)

    """
    for name, param in self.named_parameters(recurse=recurse):
        yield param

The network's parameters can be obtained via Module.parameters(), which iterates over all of the model's learnable parameters; it is a generator.

Some layers (e.g. ReLU, max pooling) have no learnable parameters, so model.parameters() yields nothing for them;

parameters() is most often seen when initializing an optimizer;
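A minimal sketch illustrating both points above (the model and learning rate are made up for this example, not taken from the original post): the ReLU layer contributes nothing to parameters(), and the generator is handed directly to the optimizer.

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),        # no learnable parameters; yields nothing below
    nn.Linear(8, 2),
)

# Only the two Linear layers contribute: 2 weights + 2 biases = 4 tensors.
print(len(list(model.parameters())))  # 4

# The typical use: hand the generator straight to an optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)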

Since parameters() is a generator, you need a loop or next() to pull values out of it:

Example:

>>> import torch
>>> import torch.nn as nn
>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(2, 2)
...     def forward(self, x):
...         out = self.linear(x)
...         return out
...
>>> net = Net()
>>> for para in net.parameters():
...     print(para)
...
Parameter containing:
tensor([[-0.1954, -0.2290],
        [ 0.5897, -0.3970]], requires_grad=True)
Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True)
>>> for para in net.named_parameters():
...     print(para)
...
('linear.weight', Parameter containing:
tensor([[-0.1954, -0.2290],
        [ 0.5897, -0.3970]], requires_grad=True))
('linear.bias', Parameter containing:
tensor([-0.1808,  0.2044], requires_grad=True))
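For completeness, here is the next() route mentioned above, continuing with the same net; the shapes follow from nn.Linear(2, 2).

params = net.parameters()   # a fresh generator
first = next(params)        # linear.weight comes first
print(first.shape)          # torch.Size([2, 2])
second = next(params)       # then linear.bias
print(second.shape)         # torch.Size([2])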

1.2 model.named_parameters():

model.named_parameters() is model.parameters() with the layer name attached: it yields tuples of two elements, the parameter name and the parameter itself;

The names carry a .weight or .bias suffix, which distinguishes weights from biases;

Source code:

def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
    r"""Returns an iterator over module parameters, yielding both the
    name of the parameter as well as the parameter itself.

    Args:
        prefix (str): prefix to prepend to all parameter names.
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        (string, Parameter): Tuple containing the name and parameter

    Example::

        >>> for name, param in self.named_parameters():
        >>>     if name in ['bias']:
        >>>         print(param.size())

    """
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

For a code example, see Section 1.1;
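One common use of the names (a standard pattern, not from the original post): split the parameters into groups by suffix, for example to apply weight decay to weights but not biases. This sketch reuses the net from Section 1.1; the hyperparameter values are placeholders.

decay, no_decay = [], []
for name, param in net.named_parameters():
    # the '.bias' suffix identifies bias parameters
    (no_decay if name.endswith('.bias') else decay).append(param)

optimizer = torch.optim.SGD(
    [{'params': decay, 'weight_decay': 1e-4},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=0.01,
)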

2 state_dict()

model.state_dict() retrieves all of the model's state, both the learnable parameters and the non-learnable state (buffers such as BatchNorm's running statistics); it returns an OrderedDict.

In effect, this takes everything model.parameters() gives and adds the non-learnable part on top;

Example:

The keys describe the network parameters; here they are the weight and bias of the two linear layers (Dropout holds no state, so it does not appear);

>>> class Net(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(10, 8)
...         self.dropout = nn.Dropout(0.5)
...         self.linear1 = nn.Linear(8, 2)
...     def forward(self, x):
...         out = self.dropout(self.linear(x))
...         out = self.linear1(out)
...         return out
...
>>> net = Net()
>>> net.state_dict()
OrderedDict([('linear.weight', tensor([[ 0.1415, -0.2228, -0.1262,  0.0992, -0.1600,  0.0141, -0.1841, -0.1907,
          0.0295, -0.1853],
        [-0.0399, -0.2487, -0.3085,  0.1602,  0.3135,  0.1379,  0.0696,  0.0362,
         -0.1619, -0.0887],
        [-0.1244, -0.1739,  0.1211, -0.2578, -0.0561,  0.0635, -0.1976, -0.2557,
          0.1761,  0.2553],
        [ 0.0912, -0.1469, -0.3012, -0.1583, -0.0028,  0.2697,  0.1947, -0.0596,
         -0.2144, -0.0785],
        [-0.1770,  0.0411,  0.1663,  0.1861,  0.2769,  0.0990,  0.1883, -0.1801,
          0.2727,  0.1219],
        [-0.1269,  0.0713,  0.2798,  0.1760,  0.0965,  0.1144,  0.2644,  0.0274,
          0.0034,  0.2702],
        [ 0.0628,  0.0682, -0.1842,  0.1461,  0.0678, -0.2264, -0.1249, -0.1715,
          0.1115,  0.2459],
        [ 0.1198, -0.2584,  0.0234,  0.2756,  0.1174, -0.1212,  0.3024, -0.2304,
         -0.2950,  0.0970]])),
             ('linear.bias', tensor([-0.3036, -0.1933,  0.2412,  0.3137, -0.3007,  0.2386, -0.1975,  0.3127])),
             ('linear1.weight', tensor([[-0.1725,  0.3027,  0.1985,  0.1394, -0.1245,  0.2913,  0.0136,  0.1633],
        [-0.1558, -0.0865, -0.3032,  0.1374,  0.2967, -0.2886,  0.0430, -0.1246]])),
             ('linear1.bias', tensor([-0.1232, -0.0690]))])
>>>
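To make the non-learnable part concrete, here is a small sketch assuming a model with a BatchNorm layer (BatchNorm keeps its running statistics as buffers): the buffers appear in state_dict() but not in named_parameters().

bn_net = nn.Sequential(nn.Linear(10, 4), nn.BatchNorm1d(4))

# parameters(): learnable tensors only
print([name for name, _ in bn_net.named_parameters()])
# ['0.weight', '0.bias', '1.weight', '1.bias']

# state_dict(): the same tensors plus the BatchNorm buffers
print(list(bn_net.state_dict().keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias',
#  '1.running_mean', '1.running_var', '1.num_batches_tracked']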
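And the model-saving use mentioned at the top, in its standard form (the file name is illustrative):

torch.save(net.state_dict(), 'net_params.pt')    # save the state dict only

net2 = Net()                                     # must have the same architecture
net2.load_state_dict(torch.load('net_params.pt'))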

References:

yaoyz105, "The meaning of model.state_dict(), model.modules(), model.children(), model.named_children(), etc. in PyTorch", CSDN blog.

"model.parameters() versus model.state_dict()", Zhihu.
