
PyTorch model visualization


1. Using dot (Graphviz)

1.1 Install graphviz and torchviz

sudo apt-get install graphviz

pip install torchviz
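A common failure is that `pip install torchviz` succeeds but rendering still breaks. torchviz depends on the Python `graphviz` package, which only produces DOT text and calls the system `dot` executable (from the apt package above) to render it. The stdlib-only helper below is a sketch for checking both before running the examples. `check_viz_environment` is a hypothetical name, not part of torchviz or graphviz.

```python
import importlib.util
import shutil


def check_viz_environment(python_pkgs=("graphviz", "torchviz")):
    """Report whether the Graphviz `dot` binary and the given Python packages are available.

    Hypothetical helper: the `graphviz` Python package only writes DOT source,
    so rendering to PDF/PNG also needs the system `dot` executable on PATH.
    """
    report = {"dot_binary": shutil.which("dot") is not None}
    for name in python_pkgs:
        # find_spec returns None when the top-level package cannot be imported
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for item, ok in check_viz_environment().items():
        print(f"{item:12s} {'OK' if ok else 'MISSING'}")
```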

1.2 Using torchviz

import torch
from torch import nn
from torchviz import make_dot

# Visualize the autograd graph of a simple MLP.
# make_dot builds a directed graph of the PyTorch operations recorded during the
# forward pass, i.e. the operations that will be called on backward.
# Subgraphs that do not require gradients are omitted.

model = nn.Sequential()
model.add_module("W0", nn.Linear(8, 16))
model.add_module("tanh", nn.Tanh())
model.add_module("W1", nn.Linear(16, 1))

x = torch.randn(1, 8)  # Variable is deprecated; plain tensors work since PyTorch 0.4
y = model(x)

make_dot(y.mean(), params=dict(model.named_parameters()))  # displayed inline in an IPython/Jupyter notebook

dot = make_dot(y.mean(), params=dict(model.named_parameters()))
dot.format = "pdf"
dot.render("model")  # writes the DOT source to "model" and the rendered graph to "model.pdf"

Tutorial: torchviz on GitHub
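torchviz's `make_dot` starts from the output's `grad_fn` and follows `next_functions` backward. That is why only operations that take part in backpropagation appear in the graph. The sketch below imitates that traversal in plain Python so it runs without PyTorch, but it is not torchviz's actual implementation.

`FakeGradFn`, `to_dot` and `build_mlp_graph` are hypothetical names. The node names only roughly mirror what PyTorch records for the MLP above, and exact names such as `AddmmBackward` vs `AddmmBackward0` vary by version.

```python
class FakeGradFn:
    """Hypothetical stand-in for a PyTorch autograd node (what tensor.grad_fn points to).

    Like real autograd nodes it exposes `next_functions`, a tuple of
    (node_or_None, index) pairs; None marks an input that does not require grad.
    """

    def __init__(self, name, next_functions=()):
        self.name = name
        self.next_functions = tuple(next_functions)


def to_dot(root):
    """Walk backward from `root` through next_functions and return DOT source."""
    ids, nodes, edges = {}, [], []

    def node_id(fn):
        # Assign a DOT identifier the first time a node is seen
        if id(fn) not in ids:
            ids[id(fn)] = f"n{len(ids)}"
            nodes.append(f'  {ids[id(fn)]} [label="{fn.name}"];')
        return ids[id(fn)]

    node_id(root)
    visited, stack = set(), [root]
    while stack:
        fn = stack.pop()
        if id(fn) in visited:
            continue
        visited.add(id(fn))
        for child, _ in fn.next_functions:
            if child is None:  # input without requires_grad: branch is left out
                continue
            # Edge points in the forward direction: child feeds into fn
            edges.append(f"  {node_id(child)} -> {node_id(fn)};")
            stack.append(child)
    return "\n".join(["digraph G {", *nodes, *edges, "}"])


def build_mlp_graph():
    """Hand-built graph loosely resembling y.mean() for the W0 -> tanh -> W1 MLP."""
    t0 = FakeGradFn("TBackward", [(FakeGradFn("W0.weight"), 0)])
    # (None, 0) plays the role of x, which does not require grad
    layer0 = FakeGradFn("AddmmBackward", [(FakeGradFn("W0.bias"), 0), (None, 0), (t0, 0)])
    tanh = FakeGradFn("TanhBackward", [(layer0, 0)])
    t1 = FakeGradFn("TBackward", [(FakeGradFn("W1.weight"), 0)])
    layer1 = FakeGradFn("AddmmBackward", [(FakeGradFn("W1.bias"), 0), (tanh, 0), (t1, 0)])
    return FakeGradFn("MeanBackward0", [(layer1, 0)])


if __name__ == "__main__":
    print(to_dot(build_mlp_graph()))
```

Saving the printed text to a file (say `graph.dot`, a hypothetical name) and running `dot -Tpdf graph.dot -o graph.pdf` renders it with the same Graphviz binary that torchviz relies on.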

2. Using tensorwatch

Install PyTorch 1.2 and tensorwatch 0.8.7.
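This example was written against those exact versions, so it can help to confirm what is actually installed before running it. The stdlib-only sketch below (Python 3.8+) reports installed versions and compares them with the pins. `installed_versions` and `matches_pin` are hypothetical helper names.

```python
from importlib import metadata  # Python 3.8+


def installed_versions(packages):
    """Return {distribution name: installed version string, or None if absent}."""
    result = {}
    for pkg in packages:
        try:
            result[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            result[pkg] = None
    return result


def matches_pin(version, pin):
    """True if `version` is `pin` or a release under it ('1.2' matches '1.2.0' and '1.2.0+cpu')."""
    return version is not None and (version == pin or version.startswith(pin + "."))


if __name__ == "__main__":
    pins = {"torch": "1.2", "tensorwatch": "0.8.7"}
    found = installed_versions(pins)
    for pkg, pin in pins.items():
        status = "ok" if matches_pin(found[pkg], pin) else "differs"
        print(f"{pkg}: installed={found[pkg]}, expected={pin} -> {status}")
```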

import tensorwatch as tw
import torchvision.models

alexnet_model = torchvision.models.alexnet()
# Input shape is (batch, channels, height, width)
tw.draw_model(alexnet_model, [1, 3, 224, 224]).save("alexnet.png")

Tutorial: tensorwatch
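The second argument to `tw.draw_model`, `[1, 3, 224, 224]`, describes the dummy input used to trace the model, in (batch, channels, height, width) order. The pure-Python sketch below traces how that shape shrinks through AlexNet's convolutional `features` stage, using the standard conv/pool output-size formula.

The layer table is transcribed from torchvision's AlexNet definition, with ReLU layers omitted because they don't change shape. `conv_out`, `trace_shapes` and `ALEXNET_FEATURES` are hypothetical names, not tensorwatch API.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool layer along one spatial axis (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kind, out_channels, kernel, stride, padding) for torchvision's AlexNet `features`
ALEXNET_FEATURES = [
    ("conv", 64, 11, 4, 2),
    ("pool", None, 3, 2, 0),
    ("conv", 192, 5, 1, 2),
    ("pool", None, 3, 2, 0),
    ("conv", 384, 3, 1, 1),
    ("conv", 256, 3, 1, 1),
    ("conv", 256, 3, 1, 1),
    ("pool", None, 3, 2, 0),
]


def trace_shapes(input_shape, layers):
    """Return [(kind, [N, C, H, W]), ...] after each layer in `layers`."""
    n, c, h, w = input_shape
    shapes = []
    for kind, out_c, k, s, p in layers:
        h, w = conv_out(h, k, s, p), conv_out(w, k, s, p)
        if kind == "conv":
            c = out_c  # pooling keeps the channel count
        shapes.append((kind, [n, c, h, w]))
    return shapes


if __name__ == "__main__":
    for kind, shape in trace_shapes([1, 3, 224, 224], ALEXNET_FEATURES):
        print(kind, shape)
```

For a 224×224 input the trace ends at `[1, 256, 6, 6]`. This matches torchvision's AlexNet classifier, whose first linear layer expects `256 * 6 * 6` features.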

3. Using hiddenlayer

pip install torch===1.4.0 torchvision===0.5.0 -f https://download.pytorch.org/whl/torch_stable.html

pip install hiddenlayer

import torch
import torchvision.models
import hiddenlayer as hl

# ResNet-101
model = torchvision.models.resnet101()

# Rather than using the default transforms, build custom ones to group
# nodes of residual and bottleneck blocks.
transforms = [
    # Fold Conv, BN, ReLU layers into one
    hl.transforms.Fold("Conv > BatchNorm > Relu", "ConvBnRelu"),
    # Fold Conv, BN layers together
    hl.transforms.Fold("Conv > BatchNorm", "ConvBn"),
    # Fold bottleneck blocks
    hl.transforms.Fold("""
        ((ConvBnRelu > ConvBnRelu > ConvBn) | ConvBn) > Add > Relu
        """, "BottleneckBlock", "Bottleneck Block"),
    # Fold residual blocks
    hl.transforms.Fold("""ConvBnRelu > ConvBnRelu > ConvBn > Add > Relu""",
                       "ResBlock", "Residual Block"),
    # Fold repeated blocks
    hl.transforms.FoldDuplicates(),
]

# Build the graph using the transforms above
graph = hl.build_graph(model, torch.zeros([1, 3, 224, 224]), transforms=transforms)
dot = graph.build_dot()           # hl.Graph -> graphviz.Digraph
dot.attr("graph", rankdir="TB")   # top-to-bottom layout (Graphviz accepts TB, LR, BT, RL)
dot.format = "pdf"
dot.render("resnet101")  # save as resnet101.pdf

Tutorial: hiddenlayer
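Each `hl.transforms.Fold` rule replaces a matching chain of ops with a single named node, and `FoldDuplicates` roughly collapses repeated identical blocks into one. To make the idea concrete, the sketch below applies similar rules to a plain list of op names. `fold_sequence`, `fold_duplicates`, `simplify`, `RULES` and the `"x2"` label format are hypothetical.

Unlike hiddenlayer, which works on the real graph, this sketch only handles linear chains. The skip connection feeding `Add` and the `|` alternative syntax are therefore ignored. Rules run in the listed order, which is why `Conv > BatchNorm > Relu` must be folded before the shorter `Conv > BatchNorm` pattern gets a chance to match.

```python
def fold_sequence(ops, pattern, new_name):
    """Replace each non-overlapping run of `pattern` in a linear op list with `new_name`."""
    out, i, n = [], 0, len(pattern)
    while i < len(ops):
        if ops[i:i + n] == pattern:
            out.append(new_name)
            i += n
        else:
            out.append(ops[i])
            i += 1
    return out


def fold_duplicates(ops):
    """Collapse consecutive identical ops into one entry labelled 'name xCOUNT'."""
    runs = []
    for op in ops:
        if runs and runs[-1][0] == op:
            runs[-1][1] += 1
        else:
            runs.append([op, 1])
    return [name if count == 1 else f"{name} x{count}" for name, count in runs]


# Applied in order, mirroring the order of the Fold transforms above
RULES = [
    (["Conv", "BatchNorm", "Relu"], "ConvBnRelu"),
    (["Conv", "BatchNorm"], "ConvBn"),
    (["ConvBnRelu", "ConvBnRelu", "ConvBn", "Add", "Relu"], "ResBlock"),
]


def simplify(ops):
    for pattern, name in RULES:
        ops = fold_sequence(ops, pattern, name)
    return fold_duplicates(ops)


if __name__ == "__main__":
    block = ["Conv", "BatchNorm", "Relu"] * 2 + ["Conv", "BatchNorm", "Add", "Relu"]
    print(simplify(block * 2))
```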

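Disclaimer: This article was contributed by a user and does not represent the views of the wpsshop blog. Copyright belongs to the original author, and this site assumes no legal responsibility. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/514449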