当前位置:   article > 正文

pytorch教程之nn.Module类详解——state_dict和parameters两个方法的差异性比较_pytorch.model.state

pytorch.model.state

前言:pytorch的模块Module类有很多的方法,前面的文章中已经介绍了四个常用的方法,这四个方法可以用于获取模块中所定义的对象(即每一个层)他们分别是children()、named_children()、modules()、named_modules()方法,本文介绍另外两个重要的方法,这两个方法会获取到模型中训练的参数(权值矩阵、偏置bias),这两个方法是model.state_dict()方法和model.parameters()方法。前面的文章参考:pytorch教程之nn.Module类详解——使用Module类来自定义模型

一、本文的模型案例

为了简单的演示,本文的模型较为简单,代码如下:

  1. import torch
  2. import torch.nn.functional as F
  3. from torch.optim import SGD
  4. class MyNet(torch.nn.Module):
  5. def __init__(self):
  6. super(MyNet, self).__init__() # 第一句话,调用父类的构造函数
  7. self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
  8. self.relu1=torch.nn.ReLU()
  9. self.max_pooling1=torch.nn.MaxPool2d(2,1)
  10. self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
  11. self.relu2=torch.nn.ReLU()
  12. self.max_pooling2=torch.nn.MaxPool2d(2,1)
  13. self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
  14. self.dense2 = torch.nn.Linear(128, 10)
  15. def forward(self, x):
  16. x = self.conv1(x)
  17. x = self.relu1(x)
  18. x = self.max_pooling1(x)
  19. x = self.conv2(x)
  20. x = self.relu2(x)
  21. x = self.max_pooling2(x)
  22. x = self.dense1(x)
  23. x = self.dense2(x)
  24. return x
  25. model = MyNet() # 构造模型

二、model.state_dict()方法

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

注意:

(1)只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等,像什么池化层、BN层这些本身没有参数的层是没有在这个字典中的;

(2)这个方法的作用一方面是方便查看某一个层的权值和偏置数据,另一方面更多的是在模型保存的时候使用。

2.1 Module的层的权值以及bias查看

  1. print(type(model.state_dict())) # 查看state_dict所返回的类型,是一个“顺序字典OrderedDict”
  2. for param_tensor in model.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
  3. print(param_tensor,'\t',model.state_dict()[param_tensor].size())
  4. '''
  5. conv1.weight torch.Size([32, 3, 3, 3])
  6. conv1.bias torch.Size([32])
  7. conv2.weight torch.Size([32, 3, 3, 3])
  8. conv2.bias torch.Size([32])
  9. dense1.weight torch.Size([128, 288])
  10. dense1.bias torch.Size([128])
  11. dense2.weight torch.Size([10, 128])
  12. dense2.bias torch.Size([10])
  13. '''

当然这里之查看每一个参数的维度信息,如果是查看具体的值也是一样的,跟字典的操作是一样的。

2.2 优化器optimizer的state_dict()方法

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

  1. optimizer = SGD(model.parameters(),lr=0.001,momentum=0.9)
  2. for var_name in optimizer.state_dict():
  3. print(var_name,'\t',optimizer.state_dict()[var_name])
  4. '''
  5. state {}
  6. param_groups [{'lr': 0.001,
  7. 'momentum': 0.9,
  8. 'dampening': 0,
  9. 'weight_decay': 0,
  10. 'nesterov': False,
  11. 'params': [1412966600640, 1412966613064, 1412966613136, 1412966613208,
  12. 1412966613280, 1412966613352, 1412966613496, 1412966613568]
  13. }]
  14. '''

三、model.parameters()方法

这个方法也会获得模型的参数信息,如下:

  1. print(type(model.parameters())) # 返回的是一个generator
  2. for para in model.parameters():
  3. print(para.size()) # 只查看形状
  4. '''
  5. torch.Size([32, 3, 3, 3])
  6. torch.Size([32])
  7. torch.Size([32, 3, 3, 3])
  8. torch.Size([32])
  9. torch.Size([128, 288])
  10. torch.Size([128])
  11. torch.Size([10, 128])
  12. torch.Size([10])
  13. '''

从这里可以看出,其实这个state_dict方法所得到结果差不多,不同的是,model.parameters()方法返回的是一个生成器generator,每一个元素是从开头到结尾的参数,parameters没有对应的key名称,是一个由纯参数组成的generator,而state_dict是一个字典,包含了一个key

其实Module还有一个与parameters类似的函数,named_parameters,而且parameters正是通过named_parameters来实现的,

看一下parameters的定义,很简单:

  1. def parameters(self, recurse=True):
  2. for name, param in self.named_parameters(recurse=recurse):
  3. yield param

来一起看一下named_parameters的简单使用。

  1. print(type(model.named_parameters())) # 返回的是一个generator
  2. for para in model.named_parameters(): # 返回的每一个元素是一个元组 tuple
  3. '''
  4. 是一个元组 tuple ,元组的第一个元素是参数所对应的名称,第二个元素就是对应的参数值
  5. '''
  6. print(para[0],'\t',para[1].size())
  7. '''
  8. conv1.weight torch.Size([32, 3, 3, 3])
  9. conv1.bias torch.Size([32])
  10. conv2.weight torch.Size([32, 3, 3, 3])
  11. conv2.bias torch.Size([32])
  12. dense1.weight torch.Size([128, 288])
  13. dense1.bias torch.Size([128])
  14. dense2.weight torch.Size([10, 128])
  15. dense2.bias torch.Size([10])
  16. '''

总结:model.state_dict()、model.parameters()、model.named_parameters()这三个方法都可以查看Module的参数信息,用于更新参数,或者用于模型的保存。

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/103487
推荐阅读
相关标签
  

闽ICP备14008679号