赞
踩
使用 PyTorch 深度学习搭建模型后,如果想查看模型结构,可以直接使用 print(model) 函数打印。但该输出结果不是特别直观,查阅发现有个能输出类似 keras 风格 model.summary() 的模型可视化工具。这里记录一下方便以后查阅。
pip install torchsummary
summary函数介绍
model
:网络模型
input_size
:网络输入图片的shape,这里不用加batch_size进去
batch_size
:批大小(batch_size)参数,默认是 -1;它只影响输出表格中第 0 维(batch 维)的显示值
device
:在 GPU 还是 CPU 上运行,默认是 "cuda"(GPU);如果想在 CPU 上执行,将参数改为小写的 "cpu" 即可
import torch
import torch.nn as nn

try:
    from torchsummary import summary
except ImportError:  # torchsummary is only needed by the demo in main()
    summary = None


class Shallow_ConvNet(nn.Module):
    """Shallow ConvNet-style classifier.

    Pipeline: temporal conv -> spatial conv -> BatchNorm -> square ->
    average pool -> safe log -> dropout -> classification conv -> softmax.

    Args:
        in_channel: number of input feature maps of the 4-D input
            (3-D input gets a singleton dim inserted at position 1, so this is
            normally 1).
        conv_channel_temp: output channels of the temporal convolution.
        kernel_size_temp: width of the (1, kernel_size_temp) temporal kernel.
        conv_channel_spat: output channels of the spatial convolution.
        kernel_size_spat: height of the (kernel_size_spat, 1) spatial kernel.
        pooling_size: width of the (1, pooling_size) average-pool window.
        pool_stride_size: horizontal stride of the average pool.
        dropoutRate: dropout probability.
        n_classes: number of output classes (channels of the final conv).
        class_kernel_size: width of the (1, class_kernel_size) class kernel.
    """

    def __init__(self, in_channel, conv_channel_temp, kernel_size_temp,
                 conv_channel_spat, kernel_size_spat, pooling_size,
                 pool_stride_size, dropoutRate, n_classes, class_kernel_size):
        super().__init__()
        self.temp_conv = nn.Conv2d(in_channels=in_channel,
                                   out_channels=conv_channel_temp,
                                   kernel_size=(1, kernel_size_temp),
                                   stride=1,
                                   bias=False)
        self.spat_conv = nn.Conv2d(in_channels=conv_channel_temp,
                                   out_channels=conv_channel_spat,
                                   kernel_size=(kernel_size_spat, 1),
                                   stride=1,
                                   bias=False)
        self.bn = nn.BatchNorm2d(num_features=conv_channel_spat)
        # conv activation (x * x) is applied inline in forward()
        self.pooling = nn.AvgPool2d(kernel_size=(1, pooling_size),
                                    stride=(1, pool_stride_size))
        # pool activation log(max(x, eps)) is applied via safe_log() in forward()
        self.dropout = nn.Dropout(p=dropoutRate)
        self.class_conv = nn.Conv2d(in_channels=conv_channel_spat,
                                    out_channels=n_classes,
                                    kernel_size=(1, class_kernel_size),
                                    bias=False)
        self.softmax = nn.Softmax(dim=1)

    def safe_log(self, x):
        """Return log(max(x, 1e-6)) element-wise, preventing log(0)."""
        return torch.log(torch.clamp(x, min=1e-6))

    def forward(self, x):
        """Return class probabilities (softmax over dim 1).

        Accepts a 4-D tensor directly; any other rank gets a singleton dim
        inserted at position 1 (e.g. (batch, C, T) -> (batch, 1, C, T)).
        The batch dimension is always preserved, including for batch size 1.
        """
        if x.dim() != 4:
            x = torch.unsqueeze(x, 1)
        x = self.temp_conv(x)
        x = self.spat_conv(x)
        x = self.bn(x)
        x = x * x  # conv activation: square
        x = self.pooling(x)
        x = self.safe_log(x)  # pool activation: log
        x = self.dropout(x)
        x = self.class_conv(x)
        x = self.softmax(x)
        # Squeeze only the trailing spatial dims: a bare squeeze() would also
        # remove the batch dim when batch_size == 1.
        out = x.squeeze(3).squeeze(2)
        return out


###============================ Initialization parameters ============================###
channels = 44
samples = 534
in_channel = 1
conv_channel_temp = 40
kernel_size_temp = 25
conv_channel_spat = 40
kernel_size_spat = channels
pooling_size = 75
pool_stride_size = 15
dropoutRate = 0.3
n_classes = 4
class_kernel_size = 30


def main():
    """Run a random batch through the model and print its shapes and structure."""
    x = torch.randn(32, 1, channels, samples)
    model = Shallow_ConvNet(in_channel, conv_channel_temp, kernel_size_temp,
                            conv_channel_spat, kernel_size_spat, pooling_size,
                            pool_stride_size, dropoutRate, n_classes,
                            class_kernel_size)
    out = model(x)
    print('===============================================================')
    print('out', out.shape)
    print('model', model)
    if summary is None:
        raise ImportError("torchsummary is required for main(); "
                          "install it with `pip install torchsummary`")
    # input_size excludes the batch dimension; batch_size only affects display.
    summary(model=model, input_size=(1, channels, samples), batch_size=32,
            device="cpu")


if __name__ == "__main__":
    main()
输出:
out torch.Size([32, 4]) model Shallow_ConvNet( (temp_conv): Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1), bias=False) (spat_conv): Conv2d(40, 40, kernel_size=(44, 1), stride=(1, 1), bias=False) (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (pooling): AvgPool2d(kernel_size=(1, 75), stride=(1, 15), padding=0) (dropout): Dropout(p=0.3, inplace=False) (class_conv): Conv2d(40, 4, kernel_size=(1, 30), stride=(1, 1), bias=False) (softmax): Softmax(dim=1) ) ---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [32, 40, 44, 510] 1,000 Conv2d-2 [32, 40, 1, 510] 70,400 BatchNorm2d-3 [32, 40, 1, 510] 80 AvgPool2d-4 [32, 40, 1, 30] 0 Dropout-5 [32, 40, 1, 30] 0 Conv2d-6 [32, 4, 1, 1] 4,800 Softmax-7 [32, 4, 1, 1] 0 ================================================================ Total params: 76,280 Trainable params: 76,280 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 2.87 Forward/backward pass size (MB): 229.69 Params size (MB): 0.29 Estimated Total Size (MB): 232.85 ----------------------------------------------------------------
旧的 torchsummary 中的 summary 在模型包含 LSTM 之类的层时会报错,需要改用新的 torchinfo 库中的 summary
pip install torchinfo
from torchinfo import summary


def main():
    """Run a random batch through Cascade_Conv_LSTM and print a torchinfo summary.

    Relies on Cascade_Conv_LSTM and the hyper-parameters (window_size,
    channels, samples, ...) being defined elsewhere in this script.
    """
    batch_size = 32
    x = torch.randn(batch_size, window_size, channels, samples)
    model = Cascade_Conv_LSTM(in_channel, out_channel_conv1, out_channel_conv2,
                              out_channel_conv3, kernel_conv123, stride_conv123,
                              padding_conv123, fc1_in, fc1_out, dropoutRate1,
                              lstm1_in, lstm1_hidden, lstm1_layer, lstm2_in,
                              lstm2_hidden, lstm2_layer, fc2_in, fc2_out,
                              dropoutRate2, fc3_in, n_classes)
    # To run on a GPU, move both the model and the input tensor, e.g.:
    # dev = torch.device('cuda:1'); model = model.to(dev); x = x.to(dev)
    out = model(x)
    print('===============================================================')
    print('out', out.shape)
    print('model', model)
    # torchinfo's input_size includes the batch dim; derive it from the actual
    # input so it cannot drift from window_size/channels/samples.
    summary(model=model, input_size=tuple(x.shape), device="cpu")


if __name__ == "__main__":
    main()
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== Cascade_Conv_LSTM [32, 4] -- ├─Sequential: 1-1 [320, 32, 10, 11] -- │ └─Conv2d: 2-1 [320, 32, 10, 11] 288 │ └─ELU: 2-2 [320, 32, 10, 11] -- ├─Sequential: 1-2 [320, 64, 10, 11] -- │ └─Conv2d: 2-3 [320, 64, 10, 11] 18,432 │ └─ELU: 2-4 [320, 64, 10, 11] -- ├─Sequential: 1-3 [320, 128, 10, 11] -- │ └─Conv2d: 2-5 [320, 128, 10, 11] 73,728 │ └─ELU: 2-6 [320, 128, 10, 11] -- ├─Sequential: 1-4 [320, 1024] -- │ └─Linear: 2-7 [320, 1024] 14,418,944 │ └─ELU: 2-8 [320, 1024] -- ├─Dropout: 1-5 [320, 1024] -- ├─LSTM: 1-6 [32, 10, 1024] 8,396,800 ├─LSTM: 1-7 [32, 10, 1024] 8,396,800 ├─Sequential: 1-8 [32, 1024] -- │ └─Linear: 2-9 [32, 1024] 1,049,600 │ └─ELU: 2-10 [32, 1024] -- ├─Dropout: 1-9 [32, 1024] -- ├─Linear: 1-10 [32, 4] 4,100 ├─Softmax: 1-11 [32, 4] -- ========================================================================================== Total params: 32,358,692 Trainable params: 32,358,692 Non-trainable params: 0 Total mult-adds (G): 13.28 ========================================================================================== Input size (MB): 0.14 Forward/backward pass size (MB): 71.21 Params size (MB): 129.43 Estimated Total Size (MB): 200.78 ==========================================================================================
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。