当前位置:   article > 正文

pytorch实现特征图可视化,代码简洁,包教包会_特征图可视化代码

特征图可视化代码

是不是要这样的效果
在这里插入图片描述
在这里插入图片描述
技术要点 1.选择一层网络,将图片的tensor放进去 2.将网络的输出plt.imshow

代码可直接复制使用,需要改的就是你的图片位置

import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc
plt.rcParams['font.sans-serif']=['STSong']
import torchvision.models as models
model = models.alexnet(pretrained=True)

#1.模型查看
# print(model)#可以看出网络一共有3层,两个Sequential()+avgpool
# model_features = list(model.children())
# print(model_features[0][3])#取第0层Sequential()中的第四层
# for index,layer in enumerate(model_features[0]):
#     print(layer)


#2. 导入数据
# 以RGB格式打开图像
# Pytorch DataLoader就是使用PIL所读取的图像格式
# 建议就用这种方法读取图像,当读入灰度图像时convert('')
def get_image_info(image_dir):
    image_info = Image.open(image_dir).convert('RGB')#是一幅图片
    # 数据预处理方法
    image_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image_info = image_transform(image_info)#torch.Size([3, 224, 224])
    image_info = image_info.unsqueeze(0)#torch.Size([1, 3, 224, 224])因为model的输入要求是4维,所以变成4维
    return image_info#变成tensor数据


#2. 获取第k层的特征图
'''
args:
k:定义提取第几层的feature map
x:图片的tensor
model_layer:是一个Sequential()特征层
'''
def get_k_layer_feature_map(model_layer, k, x):
    with torch.no_grad():
        for index, layer in enumerate(model_layer):#model的第一个Sequential()是有多层,所以遍历
            x = layer(x)#torch.Size([1, 64, 55, 55])生成了64个通道
            if k == index:
                return x


#  可视化特征图
def show_feature_map(feature_map):#feature_map=torch.Size([1, 64, 55, 55]),feature_map[0].shape=torch.Size([64, 55, 55])
                                                                         # feature_map[2].shape     out of bounds
    feature_map = feature_map.squeeze(0)#压缩成torch.Size([64, 55, 55])
    
    #以下4行,通过双线性插值的方式改变保存图像的大小
    feature_map =feature_map.view(1,feature_map.shape[0],feature_map.shape[1],feature_map.shape[2])#(1,64,55,55)
    upsample = torch.nn.UpsamplingBilinear2d(size=(256,256))#这里进行调整大小
    feature_map = upsample(feature_map)
    feature_map = feature_map.view(feature_map.shape[1],feature_map.shape[2],feature_map.shape[3])
    
    feature_map_num = feature_map.shape[0]#返回通道数
    row_num = np.ceil(np.sqrt(feature_map_num))#8
    plt.figure()
    for index in range(1, feature_map_num + 1):#通过遍历的方式,将64个通道的tensor拿出

        plt.subplot(row_num, row_num, index)
        plt.imshow(feature_map[index - 1], cmap='gray')#feature_map[0].shape=torch.Size([55, 55])
        #将上行代码替换成,可显示彩色 plt.imshow(transforms.ToPILImage()(feature_map[index - 1]))#feature_map[0].shape=torch.Size([55, 55])
        plt.axis('off')
        scipy.misc.imsave( 'feature_map_save//'+str(index) + ".png", feature_map[index - 1])
    plt.show()



if __name__ ==  '__main__':

    image_dir = r"car_logol.png"
    # 定义提取第几层的feature map
    k = 0
    image_info = get_image_info(image_dir)

    model = models.alexnet(pretrained=True)
    model_layer= list(model.children())
    model_layer=model_layer[0]#这里选择model的第一个Sequential()

    feature_map = get_k_layer_feature_map(model_layer, k, image_info)
    show_feature_map(feature_map)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89

通过上采样,改变输出图像大小
看57行代码注释,保存图像大小可以改动(回应多位网友需求)

彩色图显示

直接看70行注释
#在show_feature_map函数中加上一句,tensor数据变成Img的操作
image_PIL=transforms.ToPILImage()(feature_map[index - 1])
  • 1
  • 2
  • 3

在这里插入图片描述

如果对于matplotlib不熟练
matplotlib绘制多个子图(汉字标题,XY轴标签)& PIL.Image 11行读取文件夹中照片

碰巧,如果你看到了这篇文章,并且觉得有用的话 那就给个三连吧!
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/340886
推荐阅读
相关标签
  

闽ICP备14008679号