当前位置:   article > 正文

pytorch模型可视化的方法总结,(参数量内存,特征图内存),FLOPs和Parameters_summurry函数total mult-adds

summurry函数total mult-adds


这里主要介绍pytorch 模型的网络结构的可视化
以 SRCNN 为例子来说明可视化的方法,以及参数量的计算

模型所占内存 = (参数量内存,特征图内存),
模型计算量 = (浮点数计算量)

1. torchsummary

torchinfo

class SRCNN(nn.Module):
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x


from torchinfo import summary
if __name__ == "__main__":
    modelviz = SRCNN()
    # 打印模型结构
    print(modelviz)
    summary(modelviz, input_size=(8, 1, 8, 8), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])
    for p in modelviz.parameters():
        if p.requires_grad:
            print(p.shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

可以得到的结果如下
在这里插入图片描述
具体什么含义呢?
接下来详细解释:
这里输入以 input_size=(8, 1, 8, 8) 为例子,
1) kernel shape 和 output shape 就是滤波器的参数shape 和 中间层的一些输出的 shape
2) Para # 表示的是有多少个参数,计算conv-2d 1-1的参数量,kernelshape = [9,9]:

W + b        = 5248
9*9*64 + 64 = 5248
  • 1
  • 2

3) Multi-Adds : 统计的是浮点数运算, 计算conv-2d 1-1的计算量(浮点数运算次数):
filter(h, w, bias, channel), input(h, w, channel)

(9*9 + 1) * 64                *    (8 * 8 * 8) = 2686976
  • 1

4) Total params, Total mult-adds (M) 就是对 上面参数的求和

比如 5248+51232+801 = 57281
  • 1

5)关于size:统计的是 参数 加上 中间层的 占用内存
输入内存Input size (MB): 0.00

8*1*8*8  * 4  / 10000008*1*8*8float,每个4Byte, 除以一百万  ,约等于 0    
  • 1

中间特征内存Forward/backward pass size (MB): 0.40

8*8*8 *1+64+64+32+32+1=  99328
99328 * 4 / 1000000 = 0.397312
  • 1
  • 2

参数weight内存Params size (MB): 0.23

57281*4 / 1000000 = 0.229124
  • 1

总内存Estimated Total Size (MB): 0.63

0.4 + 0.23
  • 1

2. graphviz, torchviz

from torchviz import make_dot
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
modelviz = SRCNN().to(device)
input = torch.rand(8, 1, 8, 8).to(device)
out = modelviz(input)
print(out.shape)

# 1. 使用 torchviz 可视化
g = make_dot(out)
g.view()  # 直接在当前路径下保存 pdf 并打开
# g.render(filename='netStructure/myNetModel', view=False, format='pdf')  # 保存 pdf 到指定路径不打开
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

可视化结果是一个pdf,如下:写了比较多的步骤,所以网络结构感觉不是很清晰
在这里插入图片描述

3. 保存成pt文件后使用netron可视化

netron github:
安装:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple netron
  • 1

代码:

torch.save(modelviz, "modelviz.pt")

import netron
modelData = 'modelviz.pt'
netron.start(modelData)
  • 1
  • 2
  • 3
  • 4
  • 5

点击链接在浏览器中打开
在这里插入图片描述
在这里插入图片描述

4. tensorwatch

import tensorwatch as tw
# 3. 使用tensorwatch可视化
print(tw.model_stats(modelviz, (8, 1, 8, 8)))
tw.draw_model(modelviz, input)
  • 1
  • 2
  • 3
  • 4

打印的结果如图,可以和 summary 进行对比
在这里插入图片描述

5. get_model_complexity_info计算 FLOPs和parameters

    # 4. get_model_complexity_info
    from ptflops import get_model_complexity_info
    macs, params = get_model_complexity_info(modelviz, ( 1, 8, 8), verbose=True, print_per_layer_stat=True)
    print(macs, params)
    params = float(params[:-3])
    macs = float(macs[:-4])

    print(macs * 8, params) # 8个图像的 FLOPs, 这里的结果 和 其他方法应该一致
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

结果:
在这里插入图片描述

6. 附上直接可以执行的code

from torch import nn
import torch
from torchviz import make_dot
import tensorwatch as tw
from torchinfo import summary
import netron

class SRCNN(nn.Module):
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x



if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = 'cpu'
    modelviz = SRCNN().to(device)
    # 打印模型结构
    print(modelviz)
    summary(modelviz, input_size=(8, 1, 8, 8), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])
    for p in modelviz.parameters():
        if p.requires_grad:
            print(p.shape)
    # 创建输入, 看看输出结果

    input = torch.rand(8, 1, 8, 8).to(device)
    out = modelviz(input)
    print('out:', out.shape)
    # 1. 使用 torchviz 可视化
    g = make_dot(out)
    g.view()  # 直接在当前路径下保存 pdf 并打开
    # g.render(filename='netStructure/myNetModel', view=False, format='pdf')  # 保存 pdf 到指定路径不打开


    # 2. 保存成pt文件后进行可视化
    torch.save(modelviz, "modelviz.pt")
    modelData = 'modelviz.pt'
    netron.start(modelData)

    # 3. 使用tensorwatch可视化
    # print(tw.model_stats(modelviz, (8, 1, 8, 8)))
    # tw.draw_model(modelviz, input)


    # 4. get_model_complexity_info
    from ptflops import get_model_complexity_info
    macs, params = get_model_complexity_info(modelviz, (1, 8, 8), verbose=True, print_per_layer_stat=True)
    print(macs, params)
    params = float(params[:-3])
    macs = float(macs[:-4])

    print(macs * 8, params) # 8个图像的 FLOPs, 这里的结果 和 其他方法应该一致
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62

7. 参考

超实用的7种 pytorch 网络可视化方法,进来收藏一波
使用pytorchviz和Netron可视化pytorch网络结构

https://cloud.tencent.com/developer/article/1842049

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/514442
推荐阅读
相关标签
  

闽ICP备14008679号