当前位置:   article > 正文

利用HiddenLayer和netron进行pytorch模型结构可视化

hiddenlayer

0、简介

模型可视化是通过直观方式查看我们模型的结构。通常我们使用pytorch定义的网络模型都是代码堆叠,实现的和我们想象的是否一致呢,除了细致推敲代码外,直接通过图的方式展示出来更加直观。在这里介绍HiddenLayer和netron进行模型可视化,HiddenLayer是可以直接对pt模型进行可视化的,而netron无法直接可视化pt模型,所以我们通过将pt转为onnx模型,再通过netron进行可视化。

1、利用HiddenLayer进行模型可视化

模型可视化的方法有很多,可以看看这篇文章:超实用的7种 pytorch 网络可视化方法,进来收藏一波

这里记录一下HiddenLayer这个工具的使用,先看效果图:
在这里插入图片描述
相比较于其他工具,这个库非常简介,并且只包含给人看的节点,还能展示输入输出的shape,非常的人性化。

首先在环境中安装:pip install hiddenlayer
然后使用代码如下:

import torch
import hiddenlayer as h
from torchvision.models import resnet18

myNet = resnet18()  # 实例化 resnet18
x = torch.zeros(16, 3, 64, 64)  # 随机生成一个输入

myNetGraph = h.build_graph(myNet, x)  # 建立网络模型图
# myNetGraph.theme = h.graph.THEMES['blue']  # blue 和 basic 两种颜色,可以不要
myNetGraph.save(path='./demoModel.png', format='png')  # 保存网络模型图,可以设置 png 和 pdf 等
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

问题
我遇到的是不显示shape并出现警告
Pango-WARNING **: couldn’t load font “Times Not-Rotated 10”, falling back to “Sans Not-Rotated 10”, expect ugly output.
解决办法:

其他问题,参见:
pytorch 网络可视化(六):hiddenlayer
hiddenlayer库使用出现的一系列问题

2、使用netron进行模型可视化

首先我们需要安装netron:pip install netron
然后使用代码如下:

import netron
import torch
from torch import nn

# 定义我们的模型
class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.block1 = nn.Sequential(nn.Conv2d(3, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))
        self.block2 = nn.Sequential(nn.Conv2d(10, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))
        self.block3 = nn.Sequential(nn.Conv2d(10, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

# 进行模型推理(pt转onnx是通过跟踪计算流实现的,所以需要先推理一下)
net = TestNet()
input = torch.rand([1, 3, 10, 10])
output = net(input)

# 转为onnx模型
torch.onnx.export(net, input, "testnet.onnx", opset_version=11)
netron.start("testnet.onnx")	# 使用netron可视化onnx模型
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

执行代码会自动弹出web页面:
在这里插入图片描述

3、(高阶)使用netron进行模型可视化

通过上面的方法我们实现了模型的可视化,但是这个模型比较简单,如果来个复杂的模型,那么这个图就会很大很复杂,以至于我们都分不清和TestNet中的对应关系。

TestNet中的conv算子会转换成onnx中的conv算子,那意味着我们可以设计一个特殊的算子,暂且命名为DebugOp,转为onnx后在图中也会出现DebugOp算子,通过找到这个算子就能大致和TestNet进行关系对应。

具体代码如下:

import netron
import torch
from torch import nn

# 这里就是我们定义的特殊算子DebugOP
class DebugOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, name):  # 这个DebugOp算子将输入x直接返回,不做任何云算,插入到网络中也就不改变网络结构
        return x
    @staticmethod
    def symbolic(g, x, name):
        return g.op("my::Debug", x, name_s=name)
# 获取自定义算子的调用接口(用法上相当于实例化),后面就可以用debug_apply(x,name进行使用),在不同的地方可以传入不同的name
debug_apply = DebugOp.apply


class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.block1 = nn.Sequential(nn.Conv2d(3, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))
        self.block2 = nn.Sequential(nn.Conv2d(10, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))
        self.block3 = nn.Sequential(nn.Conv2d(10, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))

    def forward(self, x):
        x = debug_apply(x, "this is block1")	# 将我们的特殊算子插入到网络中
        x = self.block1(x)
        x = debug_apply(x, "this is block2")
        x = self.block2(x)
        x = debug_apply(x, "this is block3")
        x = self.block3(x)
        return x


net = TestNet()
input = torch.rand([1, 3, 10, 10])
output = net(input)
torch.onnx.export(net, input, "testnet1.onnx", opset_version=11)
netron.start("testnet1.onnx")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

可视化结构如下,插入了很多debug算子,并且点击查看算子属性可以看到name是我们传入的name,于是我们就能够很清楚知道下面哪一部分是block1、block2和block3了,这在复杂网络结构中寻找某些层是非常有用的。
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/656114
推荐阅读
相关标签
  

闽ICP备14008679号