The learnable parameters of a torch.nn.Module are all contained in the model's parameters and can be retrieved with the model.parameters() method;
state_dict() returns a dictionary containing the parameters (tensors) of each layer of the model, and is most often used for saving a model;
Source:
- def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
-     r"""Returns an iterator over module parameters.
-     This is typically passed to an optimizer.
-     Args:
-         recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
-     Yields:
-         Parameter: module parameter
-     Example::
-         >>> for param in model.parameters():
-         >>>     print(type(param), param.size())
-         <class 'torch.Tensor'> (20L,)
-         <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
-     """
-     for name, param in self.named_parameters(recurse=recurse):
-         yield param
Module.parameters() returns the network's parameters, iterating over all of the model's learnable parameters; it is a generator;
Some layers contain no learnable parameters (e.g. ReLU, MaxPool), so model.parameters() yields nothing for those layers;
parameters() is most often seen in optimizer initialization;
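For instance, passing the generator straight to an optimizer (a minimal sketch; the SGD choice and learning rate are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)
# parameters() hands the optimizer every learnable tensor of the model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
```

The optimizer stores the parameters in its first param group: one weight tensor and one bias tensor for this model.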
Since parameters() is a generator, you need a loop or next() to fetch the values:
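With next(), a single parameter can be pulled out directly (a small sketch; for nn.Linear the weight is registered first, so it comes out first):

```python
import torch
import torch.nn as nn

net = nn.Linear(2, 2)
# parameters() is a generator; next() pulls the first parameter (the weight)
first = next(net.parameters())
print(first.shape)  # torch.Size([2, 2])
```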
Example:
- >>> import torch
- >>> import torch.nn as nn
-
- >>> class Net(nn.Module):
- ... def __init__(self):
- ... super().__init__()
- ... self.linear = nn.Linear(2,2)
- ... def forward(self,x):
- ... out = self.linear(x)
- ... return out
- ...
- >>> net = Net()
- >>> for para in net.parameters():
- ... print(para)
- ...
-
- Parameter containing:
- tensor([[-0.1954, -0.2290],
- [ 0.5897, -0.3970]], requires_grad=True)
- Parameter containing:
- tensor([-0.1808, 0.2044], requires_grad=True)
-
- >>> for para in net.named_parameters():
- ... print(para)
- ...
- ('linear.weight', Parameter containing:
- tensor([[-0.1954, -0.2290],
- [ 0.5897, -0.3970]], requires_grad=True))
- ('linear.bias', Parameter containing:
- tensor([-0.1808, 0.2044], requires_grad=True))
named_parameters() is model.parameters() with the layer names attached: it yields tuples of two elements, the layer name and the parameter itself;
The layer name carries a suffix .weight or .bias to distinguish weights from biases;
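One common use of these name suffixes is splitting parameters into optimizer groups, e.g. to exempt biases from weight decay (a sketch; the weight-decay split is illustrative, not from the original text):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 3)
decay, no_decay = [], []
for name, param in net.named_parameters():
    # the 'bias' suffix identifies bias tensors; skip weight decay for them
    (no_decay if name.endswith('bias') else decay).append(param)

optimizer = torch.optim.SGD(
    [{'params': decay, 'weight_decay': 1e-4},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=0.01)
```

For this single nn.Linear, one tensor lands in each group: the weight in `decay`, the bias in `no_decay`.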
Source:
- def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
-     r"""Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
-     Args:
-         prefix (str): prefix to prepend to all parameter names.
-         recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
-     Yields:
-         (string, Parameter): Tuple containing the name and parameter
-     Example::
-         >>> for name, param in self.named_parameters():
-         >>>     if name in ['bias']:
-         >>>         print(param.size())
-     """
-     gen = self._named_members(
-         lambda module: module._parameters.items(),
-         prefix=prefix, recurse=recurse)
-     for elem in gen:
-         yield elem
For a code example, see section 1.1;
model.state_dict() returns all of the model's parameters, both learnable and non-learnable, as an ordered dictionary (OrderedDict).
Compared with model.parameters(), it additionally includes the non-learnable part, i.e. buffers such as BatchNorm's running statistics;
Example:
The keys describe the network parameters; here they are the weight and bias of each linear layer;
- >>> class Net(nn.Module):
- ... def __init__(self):
- ... super().__init__()
- ... self.linear = nn.Linear(10,8)
- ... self.dropout = nn.Dropout(0.5)
- ... self.linear1 = nn.Linear(8,2)
- ... def forward(self,x):
- ... out = self.dropout(self.linear(x))
- ... out = self.linear1(out)
- ... return out
- ...
- >>> net = Net()
- >>> net.state_dict()
- OrderedDict([('linear.weight', tensor([[ 0.1415, -0.2228, -0.1262, 0.0992, -0.1600, 0.0141, -0.1841, -0.1907,
- 0.0295, -0.1853],
- [-0.0399, -0.2487, -0.3085, 0.1602, 0.3135, 0.1379, 0.0696, 0.0362,
- -0.1619, -0.0887],
- [-0.1244, -0.1739, 0.1211, -0.2578, -0.0561, 0.0635, -0.1976, -0.2557,
- 0.1761, 0.2553],
- [ 0.0912, -0.1469, -0.3012, -0.1583, -0.0028, 0.2697, 0.1947, -0.0596,
- -0.2144, -0.0785],
- [-0.1770, 0.0411, 0.1663, 0.1861, 0.2769, 0.0990, 0.1883, -0.1801,
- 0.2727, 0.1219],
- [-0.1269, 0.0713, 0.2798, 0.1760, 0.0965, 0.1144, 0.2644, 0.0274,
- 0.0034, 0.2702],
- [ 0.0628, 0.0682, -0.1842, 0.1461, 0.0678, -0.2264, -0.1249, -0.1715,
- 0.1115, 0.2459],
- [ 0.1198, -0.2584, 0.0234, 0.2756, 0.1174, -0.1212, 0.3024, -0.2304,
- -0.2950, 0.0970]])), ('linear.bias', tensor([-0.3036, -0.1933, 0.2412, 0.3137, -0.3007, 0.2386, -0.1975, 0.3127])), ('linear1.weight', tensor([[-0.1725, 0.3027, 0.1985, 0.1394, -0.1245, 0.2913, 0.0136, 0.1633],
- [-0.1558, -0.0865, -0.3032, 0.1374, 0.2967, -0.2886, 0.0430, -0.1246]])), ('linear1.bias', tensor([-0.1232, -0.0690]))])
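The extra non-learnable entries, and the typical save/load round trip, can be seen with a module that registers buffers (a sketch; the file name 'net.pth' is illustrative):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
# BatchNorm's running statistics are buffers: present in state_dict(),
# absent from parameters()
param_names = {n for n, _ in net.named_parameters()}
state_keys = set(net.state_dict().keys())
print(state_keys - param_names)
# e.g. {'1.running_mean', '1.running_var', '1.num_batches_tracked'}

# saving and restoring a model via its state dict
torch.save(net.state_dict(), 'net.pth')
net2 = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
net2.load_state_dict(torch.load('net.pth'))
```

Saving the state dict rather than the whole model object is the pattern recommended by PyTorch, since the dict is independent of the class definition's location.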