当前位置:   article > 正文

pytorch可视化CNN每层的特征_pytorch 可视化每层

pytorch 可视化每层

在PyTorch中,可以使用torchvision.utils.make_grid来将特征图可视化为一个网格。具体步骤如下:
1.定义一个数据集并加载数据

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 定义数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, 
                                 transform=transforms.ToTensor(), download=True)
# 加载数据
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

2.加载预训练模型并定义一个函数,用于提取每一层的特征

import torch.nn as nn
import torchvision.models as models

# 加载预训练模型
model = models.resnet18(pretrained=True)

# 定义一个函数,用于提取每一层的特征
def get_features(x, model, layers):
    features = []
    for name, module in model._modules.items():
        x = module(x)
        if name in layers:
            features.append(x)
    return features

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

3.获取特定层的名称并将数据集中的一批数据输入到模型中,获取对应层的特征

# 获取模型中所有层的名称
all_layers = []
for name, layer in model.named_modules():
    all_layers.append(name)

# 获取需要可视化的层的名称
layers = all_layers[4:9]

# 获取一批数据
data, _ = next(iter(train_loader))

# 将数据输入到模型中,并获取对应层的特征
features = get_features(data, model, layers)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

4.将特征可视化为网格

import matplotlib.pyplot as plt
import numpy as np

# 定义一个函数,用于将特征可视化为网格
def visualize_features(features):
    nrow = len(features)
    ncols = features[0].shape[1]
    fig, axs = plt.subplots(nrow, ncols, figsize=(10, 10))

    for i in range(nrow):
        for j in range(ncols):
            img = features[i][0][j].detach().numpy()
            img = np.transpose(img, (1, 2, 0))
            img = (img - img.min()) / (img.max() - img.min())
            axs[i][j].imshow(img)
            axs[i][j].axis('off')
            if j == 0:
                axs[i][j].set_title(layers[i])
    plt.show()

# 可视化特征
visualize_features(features)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

以上代码会将提取的特定层的特征可视化为一个网格,并在网格的左侧显示对应层的名称。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/717066
推荐阅读
相关标签
  

闽ICP备14008679号