当前位置:   article > 正文

神经网络可视化:Grad cam可视化_grad-cam可视化

grad-cam可视化

说明

Grad-CAM(Gradient-weighted Class Activation Mapping)是一种用于可视化卷积神经网络(CNN)模型的注意力区域的方法。它可以帮助我们理解模型的决策过程,即模型在图像中关注的区域。

一般步骤:

导入必要的库和模型:首先,你需要导入相关的库,如 PyTorch、NumPy 和 OpenCV,并加载已经训练好的 CNN 模型。

准备输入图像:选择一张输入图像作为输入,并将其进行预处理,使其符合模型的输入要求。

前向传播:将预处理后的图像输入到 CNN 模型中,进行前向传播,获取模型的输出。

计算梯度:根据模型的输出,计算目标类别对于特征图的梯度。

计算权重:根据梯度值,计算每个特征图通道的权重。

加权求和:将每个特征图通道与其对应的权重相乘,并将它们加权求和,得到最终的热力图。

可视化:将热力图与原始图像进行叠加或叠加显示,以可视化模型关注的区域。

安装所需要的库

pip install grad-cam

完整代码

import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget


def main():
    # 设置GPU设备
    torch.cuda.set_device(2) # 为了在GPU2上运行代码,您可以使用torch.cuda.set_device()函数将PyTorch设置为使用GPU2,并确保将use_cuda参数设置为True。如果只有一个GPU请设置为0

    model = models.resnet50(pretrained=True)
    model = model.cuda()  # 将模型移动到GPU2上
    target_layers = [model.layer4[-1]]

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # 准备图像
    img_path = "img.png"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path).convert('RGB')
    img = np.array(img, dtype=np.uint8)
    img_tensor = data_transform(img)
    input_tensor = torch.unsqueeze(img_tensor, dim=0)
    input_tensor = input_tensor.cuda()  # 将输入张量移动到GPU2上

    # Grad CAM
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
    # targets = [ClassifierOutputTarget(281)]     # cat
    targets = [ClassifierOutputTarget(254)]  # dog

    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(img.astype(dtype=np.float32)/255.,
                                      grayscale_cam, use_rgb=True)

    plt.imshow(visualization)
    plt.show()


if __name__ == '__main__':
    main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

Alt
在这里插入图片描述

逐行解释

在这个部分,我们导入了所需的Python库和模块,包括NumPy、PyTorch、PIL、Matplotlib以及Grad-CAM相关的库。

import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

在main函数中,我们使用torch.cuda.set_device(2)将PyTorch设置为使用GPU2。
之后:我们加载了预训练的ResNet-50模型,并将其移动到GPU2上。

def main():
    # 设置GPU设备
    torch.cuda.set_device(2)
    model = models.resnet50(pretrained=True)
    model = model.cuda()  # 将模型移动到GPU2上
  • 1
  • 2
  • 3
  • 4
  • 5

我们加载了预训练的ResNet-50模型,并将其移动到GPU2上。

target_layers = [model.layer4[-1]]
  • 1

我们定义了一个数据转换的管道,其中包括将图像转换为张量、以及应用均值和标准差的归一化操作。

data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
  • 1
  • 2
  • 3
  • 4

我们准备了一张图像,将其转换为张量,并将其移动到GPU2上。

cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
targets = [ClassifierOutputTarget(254)]  # dog
  • 1
  • 2

我们使用Grad-CAM生成类激活图(CAM),并将其应用于原始图像上,以可视化定位到的对象区域。

grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(img.astype(dtype=np.float32)/255.,
                                  grayscale_cam, use_rgb=True)
  • 1
  • 2
  • 3
  • 4
本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/340943
推荐阅读
相关标签
  

闽ICP备14008679号