当前位置:   article > 正文

nn.Sequential、nn.ModuleList、nn.ModuleDict区别及使用技巧

nn.moduledict

目录

一、区别及联系

二、使用技巧

2.1、nn.Sequential()

2.2、nn.ModuleList()

2.3、nn.ModuleDict() 


一、区别及联系

 先通过图片总结了解三个容器方法的主要区别:

  1. nn.Sequential容器自带forward()方法,无需显示调用。nn.ModuList和nn.ModuleDict自身不具有forward()方法。
  2. nn.Sequential内的网络层必须顺序执行,上一层的输出必须与下一层的输入大小一致。
  3. nn.ModuleDict和nn.ModuleList容器内的网络层无需按顺序执行。

二、使用技巧

2.1、nn.Sequential()

可以直接添加网络层、也可以先声明后利用add_module(name:str,module)方法添加网络层,还可以使用OrderDict([*(name:str,module)])函数添加。

  1. net1 = nn.Sequential(
  2. nn.Conv2d(3,6,kernel_size=5),
  3. nn.Conv2d(6,10,kernel_size=3),
  4. nn.BatchNorm2d(10),
  5. nn.ReLU(),
  6. )
  7. net2 = nn.Sequential()
  8. net2.add_module('conv1',nn.Conv2d(3,6,kernel_size=5))
  9. net2.add_module('conv2',nn.Conv2d(6,10,kernel_size=3))
  10. net2.add_module('bn',nn.BatchNorm2d(10))
  11. net2.add_module('relu',nn.ReLU())
  12. net3 = nn.Sequential(OrderedDict([
  13. ['conv1',nn.Conv2d(3,6,kernel_size=5)],
  14. ('conv2',nn.Conv2d(6,10,kernel_size=3))
  15. ]))
  16. print('#####################')
  17. print(net1)
  18. print('#####################')
  19. print(net2)
  20. print('#####################')
  21. print(net3)

输出结果为

  1. #####################
  2. Sequential(
  3. (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  4. (1): Conv2d(6, 10, kernel_size=(3, 3), stride=(1, 1))
  5. (2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  6. (3): ReLU()
  7. )
  8. #####################
  9. Sequential(
  10. (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  11. (conv2): Conv2d(6, 10, kernel_size=(3, 3), stride=(1, 1))
  12. (bn): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  13. (relu): ReLU()
  14. )
  15. #####################
  16. Sequential(
  17. (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  18. (conv2): Conv2d(6, 10, kernel_size=(3, 3), stride=(1, 1))
  19. )

2.2、nn.ModuleList()

nn.ModuleList里面储存了不同 module,并自动将每个 module 的 parameters 添加到网络容器内容(注册),里面的module是按照List的形式顺序存储的,但是在forward中调用的时候可以随意组合。可以任意将 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,也就是说它可以使用 extend,append 等操作。
 

  1. model = nn.ModuleList([
  2. nn.Conv2d(3, 6, kernel_size=5),
  3. nn.Conv2d(6, 10, kernel_size=3),
  4. nn.BatchNorm2d(10),
  5. nn.ReLU(),
  6. ])
  7. model.extend([nn.Linear(10,10) for i in range(5)])
  8. print(model)

 输出结果为:

  1. ModuleList(
  2. (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  3. (1): Conv2d(6, 10, kernel_size=(3, 3), stride=(1, 1))
  4. (2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  5. (3): ReLU()
  6. (4): Linear(in_features=10, out_features=10, bias=True)
  7. (5): Linear(in_features=10, out_features=10, bias=True)
  8. (6): Linear(in_features=10, out_features=10, bias=True)
  9. (7): Linear(in_features=10, out_features=10, bias=True)
  10. (8): Linear(in_features=10, out_features=10, bias=True)
  11. )

运行模块可以直接使用列表索引方式或者利用for循环调用,但是顺序不固定

  1. input = torch.randn(1,6,3,3)
  2. out = model[1](input)
  3. print(out.shape)
  4. #view():[1,10,1,1]->[1,10]
  5. out = out.view(out.shape[0],out.shape[1])
  6. out = [model[i](out) for i in range(4,7)]
  7. for o in out:
  8. print(o.shape)
  9. ######################
  10. torch.Size([1, 10, 1, 1])
  11. torch.Size([1, 10])
  12. torch.Size([1, 10])
  13. torch.Size([1, 10])

2.3、nn.ModuleDict() 

nn.ModuleDict书写格式也分为两种:一种是nn.ModuleDict( {name:module , name:module ,...} ),另一种是nn.ModuleDict([ [name,module] , [name,module], ... ])

  1. class MyNet(nn.Module):
  2. def __init__(self):
  3. super(MyNet, self).__init__()
  4. self.choices = nn.ModuleDict({
  5. 'conv': nn.Conv2d(10, 10, 3),
  6. 'pool': nn.MaxPool2d(3)
  7. })
  8. self.activations = nn.ModuleDict([
  9. ['lrelu', nn.LeakyReLU()],
  10. ['prelu', nn.PReLU()]
  11. ])
  12. def forward(self, x, choice, act):
  13. # x = self.choices[choice](x)
  14. # x = self.activations[act](x)
  15. return x
  16. net = MyNet()
  17. print(net)

输出结果为

  1. MyNet(
  2. (choices): ModuleDict(
  3. (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  4. (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  5. )
  6. (activations): ModuleDict(
  7. (lrelu): LeakyReLU(negative_slope=0.01)
  8. (prelu): PReLU(num_parameters=1)
  9. )
  10. )

三、参考文献

PyTorch中的Sequential、ModuleList和ModuleDict用法总结_非晚非晚的博客-CSDN博客

nn.Sequential与nn.ModuleList_HySmiley的博客-CSDN博客

pytorch模型容器Containers nn.ModuleDict、nn.moduleList、nn.Sequential_nn.moduledict()_发呆的比目鱼的博客-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/653104
推荐阅读
相关标签
  

闽ICP备14008679号