当前位置:   article > 正文

用torch.nn.Sequential()搭建神经网络模型_torch sequential

torch sequential

原始定义方式与 nn.Sequential 两种定义方式实例:

可以看到使用torch.nn.Sequential()搭建神经网络模型非常的方便,少写很多的code

  1. import torch
  2. import torch.nn as nn
  3. # -------------------------方式一:传统网络定义方式--------------------------------
  4. class Net(nn.Module):
  5. def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
  6. super(Net, self).__init__()
  7. self.linear1 = nn.Linear(in_dim, n_hidden_1)
  8. self.Relu1 = nn.ReLU(True)
  9. self.linear2 = nn.Linear(n_hidden_1, n_hidden_2)
  10. self.Relu2 = nn.ReLU(True)
  11. self.linear3 = nn.Linear(n_hidden_2, out_dim)
  12. def forward(self, x):
  13. x = self.linear1(x)
  14. x = self.Relu1(x)
  15. x = self.linear2(x)
  16. x = self.Relu2(x)
  17. x = self.linear3(x)
  18. return x
  19. # -------------------------方式二:使用nn.Sequential定义网络------------------------
  20. class Net(nn.Module):
  21. def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
  22. super(Net, self).__init__()
  23. self.layer = nn.Sequential(
  24. nn.Linear(in_dim, n_hidden_1), # (18,15)
  25. nn.ReLU(True),
  26. nn.Linear(n_hidden_1, n_hidden_2), # (15,10)
  27. nn.ReLU(True),
  28. nn.Linear(n_hidden_2, out_dim) # (10,1)
  29. )
  30. def forward(self, x):
  31. x = self.layer(x)
  32. return x
  33. # instantiation
  34. net = Net(18, 15, 10, 1)
  35. # create random input to model
  36. input = torch.randn(30, 18)
  37. # output the predicted value
  38. predict = net(input)
  39. print(predict.size())
  40. print(net)


torch.nn.Sequential是一个Sequential容器,模块将按照构造函数中传递的顺序添加到模块中。通俗的话说,就是根据自己的需求,把不同的函数组合成一个(小的)模块使用或者把组合的模块添加到自己的网络中

一、第一种方式(可以配合一些条件判断语句动态添加)

  • 模板——torch.nn.Sequential()的一个对象.add_module(name, module)。
  • name:某层次的名字;module:需要添加的子模块,如卷积、激活函数等等。
  • 添加子模块到当前模块中。
  • 可以通过 name 属性来访问添加的子模块。
  • 输出后每一层的名字:不是采用默认的命名方式(按序号 0,1,2,3…),而是按照name属性命名!!
  1. import torch.nn as nn
  2. model = nn.Sequential()
  3. model.add_module("conv1", nn.Conv2d(1, 20, 5))
  4. model.add_module('relu1', nn.ReLU())
  5. model.add_module('conv2', nn.Conv2d(20, 64, 5))
  6. model.add_module('relu2', nn.ReLU())
  7. # 输出
  8. Sequential(
  9. (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  10. (relu1): ReLU()
  11. (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  12. (relu2): ReLU()
  13. )

注意!!!nn.module也有add_module()对象

  1. # 被添加的module可以通过 name 属性来获取。
  2. import torch.nn as nn
  3. class Model(nn.Module):
  4. def __init__(self):
  5. super(Model, self).__init__()
  6. self.add_module("conv", nn.Conv2d(10, 20, 4))
  7. # self.conv = nn.Conv2d(10, 20, 4) 和上面这个增加module的方式等价
  8. model = Model()
  9. print(model.conv) # 通过name属性访问添加的子模块
  10. print(model)
  11. # 输出:注意子模块的命名方式
  12. Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
  13. Model(
  14. (conv): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
  15. )

二、第二种方式

  • 模板——nn.Sequential(*module)
  • 输出的每一层的名字:采用默认的命名方式(按序号 0,1,2,3…)
  1. import torch.nn as nn
  2. model = nn.Sequential(
  3. nn.Conv2d(1,20,5),
  4. nn.ReLU(),
  5. nn.Conv2d(20,64,5),
  6. nn.ReLU()
  7. )
  8. print(model)
  9. # 输出:注意命名方式
  10. Sequential(
  11. (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  12. (1): ReLU()
  13. (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  14. (3): ReLU()
  15. )

三、第三种方式

  • 模板——nn.Sequential(OrderedDict([*(name, module)]))
  • 输出后每一层的名字:不是采用默认的命名方式(按序号 0,1,2,3…),而是按照name属性命名!!
  1. import collections
  2. import torch.nn as nn
  3. model = nn.Sequential(collections.OrderedDict([('conv1', nn.Conv2d(1, 20, 5)), ('relu1', nn.ReLU()),
  4. ('conv2', nn.Conv2d(20, 64, 5)),
  5. ('relu2', nn.ReLU())
  6. ]))
  7. print(model)
  8. # 输出:注意子模块命名方式
  9. Sequential(
  10. (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  11. (relu1): ReLU()
  12. (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  13. (relu2): ReLU()
  14. )

Pytorch系列1: torch.nn.Sequential()讲解_xddwz的博客-CSDN博客_torch.nn.sequential

【深度学习笔记】用torch.nn.Sequential()搭建神经网络模型_Murphy.AI 的文章-CSDN博客

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号