当前位置:   article > 正文

pytorch中用nn.sequential搭建简单网络_pytorch sequential直接组网

pytorch sequential直接组网

需要注意的点:

1、用SummaryWriter中的add_graph功能可以可视化神经网络结构

2、nn.Flatten()是类,可以和各层一起建立,本版本Pycharm无法正常联想

采用了用sequential和不用sequential两种方式搭建:

  1. # 采用传统方式和nn.sequential方式建立一个简单卷积网络,从而说明sequential的作用。
  2. import torch
  3. from torch import nn
  4. # 方法一:不采用sequential建立网络
  5. from torch.utils.tensorboard import SummaryWriter
  6. class My_nn(nn.Module):
  7. def __init__(self):
  8. super(My_nn, self).__init__()
  9. self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
  10. self.maxpool1 = nn.MaxPool2d(2)
  11. self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
  12. self.maxpool2 = nn.MaxPool2d(2)
  13. self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
  14. self.maxpool3 = nn.MaxPool2d(2)
  15. self.flatten = nn.Flatten()
  16. self.linear1 = nn.Linear(1024, 64)
  17. self.linear2 = nn.Linear(64, 10)
  18. def forward(self, x):
  19. x = self.conv1(x)
  20. x = self.maxpool1(x)
  21. x = self.conv2(x)
  22. x = self.maxpool2(x)
  23. x = self.conv3(x)
  24. x = self.maxpool3(x)
  25. x = self.flatten(x)
  26. x = self.linear1(x)
  27. x = self.linear2(x)
  28. return x
  29. class My_nn_s(nn.Module):
  30. def __init__(self):
  31. super(My_nn_s, self).__init__()
  32. self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
  33. nn.MaxPool2d(2),
  34. nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
  35. nn.MaxPool2d(2),
  36. nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
  37. nn.MaxPool2d(2),
  38. nn.Flatten(),
  39. nn.Linear(1024, 64),
  40. nn.Linear(64, 10)
  41. )
  42. def forward(self, x):
  43. x = self.model(x)
  44. return x
  45. if __name__ == '__main__':
  46. input_t = torch.rand([64, 3, 32, 32]) # 初始化输入
  47. print(input_t.shape) # torch.Size([64, 3, 32, 32])
  48. # 实例化类My_nn
  49. my_nn = My_nn()
  50. out_put = my_nn(input_t)
  51. print(out_put.shape) # torch.Size([64, 10])
  52. print(my_nn) # 显示网络结构
  53. # 实例化类My_nn_s
  54. my_nn_s = My_nn_s()
  55. out_put = my_nn_s(input_t)
  56. print(out_put.shape) # torch.Size([64, 10])
  57. print(my_nn_s) # 显示网络结构
  58. # 可视化类My_nn_s
  59. writer = SummaryWriter('nn_seq')
  60. writer.add_graph(my_nn_s, input_t)
  61. writer.close() # 不关闭没法显示

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/537878
推荐阅读
相关标签
  

闽ICP备14008679号