赞
踩
当按照该链接中的命令无法完全可视化pytorch定义的模型时(比如信息显示不全,没有output,没有按照模型执行的逻辑顺序)
使用如下方法:
需要定义输入数据,用1填充就可以: input_data = torch.ones(1, 1, 28, 28),之后才能显示输出的可视化效果,类似于keras的可视化
class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5) # 输入1 输出10 self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=3) # 输入10 输出20 self.conv3 = torch.nn.Conv2d(20, 10, kernel_size=3) # 输入20 输出10 self.pooling = torch.nn.MaxPool2d(2) self.pooling = torch.nn.MaxPool2d(2) self.flatten = torch.nn.Flatten() self.fc1 = torch.nn.Linear(90, 64) # 输入90 输出64 self.fc2 = torch.nn.Linear(64, 32) self.fc3 = torch.nn.Linear(32, 10) def forward(self, x): batch_size = x.size(0) x = F.relu(self.pooling(self.conv1(x))) # (N,1,28,28)->(N,10,24,24)->(N,10,12,12) x = F.relu(self.pooling(self.conv2(x))) # (N,10,12,12)->(N,20,10,10)->(N,20,5,5) x = F.relu(self.conv3(x)) # (N,20,5,5)->(N,10,3,3) x = self.flatten(x) # (N,10,3,3)->(N,90) x = self.fc1(x) x = self.fc2(x) x = self.fc3(x) return x # 需要将input_data定义为torch.ones(1, 1, 28, 28) # (1, 1, 28, 28) 表示 (样本数量, 通道数, 长, 宽) # 样本数量可以随便取,torch.ones from torchsummary import summary model = Net() modelVisual = summary(model, input_data = torch.ones(1, 1, 28, 28), device='cpu') # device默认是cuda
输出结果如下:
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== ├─Conv2d: 1-1 [-1, 10, 24, 24] 260 ├─MaxPool2d: 1-2 [-1, 10, 12, 12] -- ├─Conv2d: 1-3 [-1, 20, 10, 10] 1,820 ├─MaxPool2d: 1-4 [-1, 20, 5, 5] -- ├─Conv2d: 1-5 [-1, 10, 3, 3] 1,810 ├─Flatten: 1-6 [-1, 90] -- ├─Linear: 1-7 [-1, 64] 5,824 ├─Linear: 1-8 [-1, 32] 2,080 ├─Linear: 1-9 [-1, 10] 330 ========================================================================================== Total params: 12,124 Trainable params: 12,124 Non-trainable params: 0 Total mult-adds (M): 0.35 ========================================================================================== Input size (MB): 0.00 Forward/backward pass size (MB): 0.06 Params size (MB): 0.05 Estimated Total Size (MB): 0.11 ==========================================================================================
pytorch版本
torch.__version__
'1.12.0+cu113'
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。