赞
踩
模型所占内存 = (参数量内存,特征图内存),
模型计算量 = (浮点数计算量)
from torch import nn


class SRCNN(nn.Module):
    """SRCNN: three convolutions with a ReLU after the first two.

    Each convolution pads by ``kernel_size // 2``; with the default stride
    of 1 and these odd kernel sizes the output keeps the input's height
    and width.

    Args:
        num_channels: Channel count of both the input and the output.
    """

    def __init__(self, num_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return conv3(relu(conv2(relu(conv1(x)))))."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x


if __name__ == "__main__":
    # Imported only when run as a script, so the model class can be
    # imported without torchinfo being installed.
    from torchinfo import summary

    modelviz = SRCNN()
    # Print the model structure.
    print(modelviz)
    summary(
        modelviz,
        input_size=(8, 1, 8, 8),
        col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    )
    # Show the shape of every trainable parameter tensor.
    for p in modelviz.parameters():
        if p.requires_grad:
            print(p.shape)
可以得到的结果如下
具体什么含义呢?
接下来详细解释:
这里输入以 input_size=(8, 1, 8, 8) 为例子,
1) kernel shape 和 output shape 就是滤波器的参数shape 和 中间层的一些输出的 shape
2) Param # 表示该层的参数个数。以 conv-2d 1-1 为例,kernel shape = [9, 9],输入通道数为 1,输出通道数为 64:
W + b = 5248
9*9*1*64 + 64 = 5248(卷积核高 × 宽 × 输入通道数 × 输出通道数,再加上每个输出通道一个偏置)
3) Mult-Adds:统计的是乘加运算(multiply-accumulate)的次数,常用来近似浮点运算量。以 conv-2d 1-1 的计算量为例:
每个输出元素的运算量 = 卷积核 h * w * 输入通道数 + 1(bias);输出元素个数 = 输出通道数 * (batch * 输出 h * 输出 w)
(9*9*1 + 1) * 64 * (8 * 8 * 8) = 2686976
4) Total params, Total mult-adds (M) 就是对 上面参数的求和
比如 5248+51232+801 = 57281
5)关于size:统计的是 参数 加上 中间层的 占用内存
输入内存Input size (MB): 0.00
是 8*1*8*8 * 4 / 1000000, 8*1*8*8 个float,每个4Byte, 除以一百万 ,约等于 0
中间特征内存Forward/backward pass size (MB): 0.40
8*8*8 * (1+64+64+32+32+1) = 99328
99328 * 4 / 1000000 = 0.397312
参数weight内存Params size (MB): 0.23
57281*4 / 1000000 = 0.229124
总内存Estimated Total Size (MB): 0.63
0.00 + 0.40 + 0.23 = 0.63(输入内存 + 中间特征内存 + 参数内存)
from torchviz import make_dot

# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

modelviz = SRCNN()
modelviz = modelviz.to(device)

# Random batch of shape (8, 1, 8, 8) placed on the same device as the model.
# NOTE: `input` shadows the builtin; the name is kept because the
# tensorwatch snippet later in this file reads it.
input = torch.rand(8, 1, 8, 8).to(device)
out = modelviz(input)
print(out.shape)

# 1. Visualize with torchviz
g = make_dot(out)
g.view()  # saves a PDF in the current directory and opens it
# g.render(filename='netStructure/myNetModel', view=False, format='pdf')  # save the PDF to the given path without opening it
可视化结果是一个pdf,如下:写了比较多的步骤,所以网络结构感觉不是很清晰
netron github:
安装:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple netron
代码:
import netron

# 2. Save the model to a .pt file, then open that same file with Netron.
modelData = 'modelviz.pt'
torch.save(modelviz, modelData)
netron.start(modelData)
点击链接在浏览器中打开
import tensorwatch as tw

# 3. Visualize with tensorwatch
# Per-layer statistics for an input of shape (8, 1, 8, 8).
stats = tw.model_stats(modelviz, (8, 1, 8, 8))
print(stats)
# Draw the model graph; `input` is the tensor created in the torchviz snippet.
tw.draw_model(modelviz, input)
打印的结果如图,可以和 summary 进行对比
# 4. get_model_complexity_info
from ptflops import get_model_complexity_info

# The input resolution (1, 8, 8) has no batch dimension, so the reported
# figures are for a single image.
macs, params = get_model_complexity_info(
    modelviz, (1, 8, 8), verbose=True, print_per_layer_stat=True
)
print(macs, params)
# Both values are returned as text ("<number> <unit>" by default). Parse the
# leading numeric token instead of slicing off a fixed number of trailing
# characters: a fixed slice drops a digit when the unit suffix is shorter
# than the slice, and fails outright when there is no suffix at all.
params = float(params.split()[0])
macs = float(macs.split()[0])
# Scale the per-image MACs to 8 images so the number can be compared with the
# other tools. Both values stay in the unit ptflops chose for its strings.
print(macs * 8, params)
结果:
from torch import nn
import torch
from torchviz import make_dot
import tensorwatch as tw
from torchinfo import summary
import netron


class SRCNN(nn.Module):
    """SRCNN: three convolutions with a ReLU after the first two.

    Each convolution pads by ``kernel_size // 2``; with the default stride
    of 1 and these odd kernel sizes the output keeps the input's height
    and width.

    Args:
        num_channels: Channel count of both the input and the output.
    """

    def __init__(self, num_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return conv3(relu(conv2(relu(conv1(x)))))."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = 'cpu'
    modelviz = SRCNN().to(device)

    # Print the model structure.
    print(modelviz)
    summary(
        modelviz,
        input_size=(8, 1, 8, 8),
        col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    )
    # Show the shape of every trainable parameter tensor.
    for p in modelviz.parameters():
        if p.requires_grad:
            print(p.shape)

    # Build an input batch and look at the output.
    dummy_input = torch.rand(8, 1, 8, 8).to(device)
    out = modelviz(dummy_input)
    print('out:', out.shape)

    # 1. Visualize with torchviz
    g = make_dot(out)
    g.view()  # saves a PDF in the current directory and opens it
    # g.render(filename='netStructure/myNetModel', view=False, format='pdf')  # save the PDF to the given path without opening it

    # 2. Save the model to a .pt file, then visualize that file with Netron.
    model_path = 'modelviz.pt'
    torch.save(modelviz, model_path)
    netron.start(model_path)

    # 3. Visualize with tensorwatch
    # print(tw.model_stats(modelviz, (8, 1, 8, 8)))
    # tw.draw_model(modelviz, dummy_input)

    # 4. get_model_complexity_info
    from ptflops import get_model_complexity_info

    # The input resolution (1, 8, 8) has no batch dimension, so the reported
    # figures are for a single image.
    macs, params = get_model_complexity_info(
        modelviz, (1, 8, 8), verbose=True, print_per_layer_stat=True
    )
    print(macs, params)
    # Both values are returned as text ("<number> <unit>" by default). Parse
    # the leading numeric token instead of slicing off a fixed number of
    # trailing characters: a fixed slice drops a digit when the unit suffix
    # is shorter than the slice, and fails when there is no suffix at all.
    params = float(params.split()[0])
    macs = float(macs.split()[0])
    # Scale the per-image MACs to 8 images so the number can be compared with
    # the other tools. Both values stay in the unit ptflops chose.
    print(macs * 8, params)
超实用的7种 pytorch 网络可视化方法,进来收藏一波
使用pytorchviz和Netron可视化pytorch网络结构
https://cloud.tencent.com/developer/article/1842049
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。