赞
踩
目录
torchsummary感觉没有torchstat全,因此这里只演示torchstat
- #pip install torchstat #这是安装
-
- from torchstat import stat
-
- # 导入模型,输入一张输入图片的尺寸
- # 注意输入图的大小
- stat(model, (3, 48, 48))
-
- #coding:utf8
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torchviz import make_dot
- import numpy as np
-
- ## 简单模型定义
## Simple 3-conv-layer classification model
class simpleconv3(nn.Module):
    """A small 3-layer CNN classifier.

    Three stride-2 conv+BN+ReLU stages (3->64->128->256 channels), global
    average pooling, then a single linear classification head.

    Args:
        nclass: number of output classes (size of the final linear layer).
    """

    def __init__(self, nclass):
        super(simpleconv3, self).__init__()
        # Each conv halves the spatial resolution (stride=2, padding=1).
        self.conv1 = nn.Conv2d(3, 64, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.fc = nn.Linear(256, nclass)

    def forward(self, x):
        """Run the network; returns logits of shape (batch, nclass)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Global average pooling to 1x1: generalizes the original hard-coded
        # AvgPool2d(6) (valid only for 48x48 inputs) to any input size, and
        # avoids constructing a new module object on every forward call.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)  # flatten to (batch, 256)
        x = self.fc(x)
        return x
-
if __name__ == '__main__':
    # Build a dummy 48x48 RGB input and render the autograd graph with torchviz.
    # torch.autograd.Variable is deprecated (a no-op since PyTorch 0.4);
    # plain tensors participate in autograd directly.
    x = torch.randn(1, 3, 48, 48)
    model = simpleconv3(4)
    y = model(x)
    g = make_dot(y)  # trace the computation graph from the output tensor
    g.view()         # write and open the rendered graph (requires graphviz)
- *****1******** torch.Size([1, 256, 6, 6])
- *****2******** torch.Size([1, 256, 1, 1])
- *****3******** torch.Size([1, 256])
- *****4******** torch.Size([1, 4])
- module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
- 0 conv1 3 48 48 64 24 24 1728.0 0.14 1,953,792.0 995,328.0 34560.0 147456.0 24.95% 182016.0
- 1 bn1 64 24 24 64 24 24 128.0 0.14 147,456.0 73,728.0 147968.0 147456.0 6.52% 295424.0
- 2 conv2 64 24 24 128 12 12 73728.0 0.07 21,215,232.0 10,616,832.0 442368.0 73728.0 36.43% 516096.0
- 3 bn2 128 12 12 128 12 12 256.0 0.07 73,728.0 36,864.0 74752.0 73728.0 2.60% 148480.0
- 4 conv3 128 12 12 256 6 6 294912.0 0.04 21,224,448.0 10,616,832.0 1253376.0 36864.0 23.31% 1290240.0
- 5 bn3 256 6 6 256 6 6 512.0 0.04 36,864.0 18,432.0 38912.0 36864.0 2.07% 75776.0
- 6 fc 256 4 1028.0 0.00 2,044.0 1,024.0 5136.0 16.0 4.11% 5152.0
- total 372292.0 0.49 44,653,564.0 22,359,040.0 5136.0 16.0 99.99% 2513184.0
- ============================================================================================================================================
- Total params: 372,292
- --------------------------------------------------------------------------------------------------------------------------------------------
- Total memory: 0.49MB
- Total MAdd: 44.65MMAdd
- Total Flops: 22.36MFlops
- Total MemR+W: 2.4MB
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。