
torchstat and torchsummary: tools for computing parameter counts and other metrics for PyTorch models

torchstat

Contents

Installation and usage

Defining a simple network

Output


torchsummary does not feel as complete as torchstat, so only torchstat is demonstrated here; a short torchsummary sketch follows below for comparison.
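
For comparison, torchsummary is used like this (a minimal sketch, assuming it is installed via pip install torchsummary and using torchvision's resnet18 as a stand-in model; it prints per-layer output shapes and parameter counts, but not the MAdd/Flops/memory-traffic columns that torchstat reports):

from torchsummary import summary
import torchvision.models as models

model = models.resnet18()
# input size is (channels, height, width), without the batch dimension
summary(model, (3, 224, 224), device="cpu")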

Installation and usage

# pip install torchstat   # installation
from torchstat import stat
# import the model, then pass the size of a single input image
# note the input image size
stat(model, (3, 48, 48))
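
A complete, runnable version of the snippet above (a sketch assuming torchvision is installed, with resnet18 standing in for model):

import torchvision.models as models
from torchstat import stat

model = models.resnet18()     # any nn.Module can be passed here
stat(model, (3, 224, 224))    # (channels, height, width), no batch dimension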

Defining a simple network

# coding:utf8
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot
import numpy as np

## Simple model definition
class simpleconv3(nn.Module):
    def __init__(self, nclass):
        super(simpleconv3, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.fc = nn.Linear(256, nclass)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        print("*****1********", x.shape)  # (1, 256, 6, 6)
        x = nn.AvgPool2d(6)(x)  # average pooling with a 6x6 kernel
        print("*****2********", x.shape)  # (1, 256, 1, 1)
        x = x.view(x.size(0), -1)
        print("*****3********", x.shape)  # (1, 256)
        x = self.fc(x)  # the input must have 256 features to match the fc layer
        print("*****4********", x.shape)  # (1, 4)
        return x

if __name__ == '__main__':
    import torch
    from torch.autograd import Variable
    x = Variable(torch.randn(1, 3, 48, 48))
    model = simpleconv3(4)
    y = model(x)
    g = make_dot(y)
    g.view()
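
The per-layer table in the Output section below comes from running torchstat on this same model; a minimal sketch (the print statements in forward fire again, because stat performs a forward pass internally):

from torchstat import stat

model = simpleconv3(4)
stat(model, (3, 48, 48))   # same size as the input image above, without the batch dimension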

Output

*****1******** torch.Size([1, 256, 6, 6])
*****2******** torch.Size([1, 256, 1, 1])
*****3******** torch.Size([1, 256])
*****4******** torch.Size([1, 4])
  module name  input shape  output shape    params  memory(MB)          MAdd         Flops  MemRead(B)  MemWrite(B)  duration[%]  MemR+W(B)
0       conv1      3 48 48      64 24 24    1728.0        0.14   1,953,792.0     995,328.0     34560.0     147456.0       24.95%   182016.0
1         bn1     64 24 24      64 24 24     128.0        0.14     147,456.0      73,728.0    147968.0     147456.0        6.52%   295424.0
2       conv2     64 24 24     128 12 12   73728.0        0.07  21,215,232.0  10,616,832.0    442368.0      73728.0       36.43%   516096.0
3         bn2    128 12 12     128 12 12     256.0        0.07      73,728.0      36,864.0     74752.0      73728.0        2.60%   148480.0
4       conv3    128 12 12      256 6 6   294912.0        0.04  21,224,448.0  10,616,832.0   1253376.0      36864.0       23.31%  1290240.0
5         bn3      256 6 6      256 6 6      512.0        0.04      36,864.0      18,432.0     38912.0      36864.0        2.07%    75776.0
6          fc          256            4     1028.0        0.00       2,044.0       1,024.0      5136.0         16.0        4.11%     5152.0
total                                     372292.0        0.49  44,653,564.0  22,359,040.0      5136.0         16.0       99.99%  2513184.0
============================================================================================================================================
Total params: 372,292
--------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 0.49MB
Total MAdd: 44.65MMAdd
Total Flops: 22.36MFlops
Total MemR+W: 2.4MB

For the calculation formulas, see: 计算VGG16的参数量_RayChiu_Labloy的博客-CSDN博客
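
As a quick sanity check on the table above, the parameter counts can be reproduced by hand: a bias-free Conv2d has in_channels * out_channels * k * k parameters, a BatchNorm2d has 2 * num_features (weight and bias), and a Linear has in_features * out_features + out_features. A short sketch:

conv1 = 3 * 64 * 3 * 3       # 1,728
bn1   = 2 * 64               # 128
conv2 = 64 * 128 * 3 * 3     # 73,728
bn2   = 2 * 128              # 256
conv3 = 128 * 256 * 3 * 3    # 294,912
bn3   = 2 * 256              # 512
fc    = 256 * 4 + 4          # 1,028
print(conv1 + bn1 + conv2 + bn2 + conv3 + bn3 + fc)   # 372292, matching "Total params"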
