赞
踩
需要注意的点:
1、用SummaryWriter中的add_graph功能可以可视化神经网络结构
2、nn.Flatten()是类,可以和各层一起建立,本版本Pycharm无法正常联想
采用了用sequential和不用sequential两种方式搭建:
- # 采用传统方式和nn.sequential方式建立一个简单卷积网络,从而说明sequential的作用。
- import torch
- from torch import nn
-
-
- # 方法一:不采用sequential建立网络
- from torch.utils.tensorboard import SummaryWriter
-
-
- class My_nn(nn.Module):
- def __init__(self):
- super(My_nn, self).__init__()
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
- self.maxpool1 = nn.MaxPool2d(2)
- self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
- self.maxpool2 = nn.MaxPool2d(2)
- self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
- self.maxpool3 = nn.MaxPool2d(2)
- self.flatten = nn.Flatten()
- self.linear1 = nn.Linear(1024, 64)
- self.linear2 = nn.Linear(64, 10)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.maxpool1(x)
- x = self.conv2(x)
- x = self.maxpool2(x)
- x = self.conv3(x)
- x = self.maxpool3(x)
- x = self.flatten(x)
- x = self.linear1(x)
- x = self.linear2(x)
- return x
-
- class My_nn_s(nn.Module):
- def __init__(self):
- super(My_nn_s, self).__init__()
- self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
- nn.MaxPool2d(2),
- nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
- nn.MaxPool2d(2),
- nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
- nn.MaxPool2d(2),
- nn.Flatten(),
- nn.Linear(1024, 64),
- nn.Linear(64, 10)
- )
-
-
- def forward(self, x):
- x = self.model(x)
- return x
-
- if __name__ == '__main__':
- input_t = torch.rand([64, 3, 32, 32]) # 初始化输入
- print(input_t.shape) # torch.Size([64, 3, 32, 32])
-
- # 实例化类My_nn
- my_nn = My_nn()
- out_put = my_nn(input_t)
- print(out_put.shape) # torch.Size([64, 10])
- print(my_nn) # 显示网络结构
-
- # 实例化类My_nn_s
- my_nn_s = My_nn_s()
- out_put = my_nn_s(input_t)
- print(out_put.shape) # torch.Size([64, 10])
- print(my_nn_s) # 显示网络结构
-
- # 可视化类My_nn_s
- writer = SummaryWriter('nn_seq')
- writer.add_graph(my_nn_s, input_t)
- writer.close() # 不关闭没法显示
-
-
-
-
-
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。