当前位置:   article > 正文

【笔记 Pytorch】模型网络结构、网络参数可视化_pytorch打印网络结构

pytorch打印网络结构

查看网络结构

打印方式

torchsummary 方式(输入格式不好控制)

参考网址

import torch
import torchvision
from torchsummary import summary          #使用 pip install torchsummary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = torchvision.models.vgg16().to(device)

# summary(your_model, input_size=(channels, H, W))
summary(vgg, input_size=(3, 224, 224))

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

print方式 (简便,存在输出顺序与执行顺序不一致的问题)

for name, parameters in your_model.named_parameters():
    print(name, ':', parameters.size())
  • 1
  • 2

可视化方式

HiddenLayer

 pip install hiddenlayer
  • 1
 import hiddenlayer as h
 vis_graph = h.build_graph(MyConvNet, torch.zeros([1 ,1, 28, 28]))   # 获取绘制图像的对象
 vis_graph.theme = h.graph.THEMES["blue"].copy()     # 指定主题颜色
 vis_graph.save("./demo1.png")   # 保存图像的路径
  • 1
  • 2
  • 3
  • 4

PytorchVIZ

 pip install torchviz
  • 1
 from torchviz import make_dot
 x = torch.randn(1, 1, 28, 28).requires_grad_(True)  # 定义一个网络的输入值
 y = MyConvNet(x)    # 获取网络的预测值
 ​
 MyConvNetVis = make_dot(y, params=dict(list(MyConvNet.named_parameters()) + [('x', x)]))
 MyConvNetVis.format = "png"
 # 指定文件生成的文件夹
 MyConvNetVis.directory = "data"
 # 生成文件
 MyConvNetVis.view()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

tensorboardX(会存在一些版本的匹配问题,不太直观)

graphviz + torchviz (依赖于graphviz和GitHub第三方库torchviz)

微软的tensorwatch (只能在jupyter notebook中使用)

netron可视化工具(.pt 或者是 .pth 文件)

查看网络参数

params = list(model.parameters())
k = 0
for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
                l *= j
        print("该层参数和:" + str(l))
        k = k + l
print("总参数数量和:" + str(k))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/507936
推荐阅读
相关标签
  

闽ICP备14008679号