当前位置:   article > 正文

对vit(Vision Transformer)的注意力可视化。使用grad_cam方法_vit的注意力可视化

vit的注意力可视化

一、环境准备

注意安装包是pip install grad_cam而不是pytorch_grad_cam。一个是包名一个是导入名。之前发现怎么都装不上。

pip install "grad-cam==1.4.0"
  • 1

导入时调用pytorch_grad_cam

```python
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus,AblationCAM, \
                            XGradCAM, EigenCAM, EigenGradCAM,LayerCAM,FullGrad
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
import cv2
import numpy as np
import torch
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

二、加载预训练的vit模型(离线加载或在线加载)

因为网络问题,使用离线定义网络与离线加载模型方法,也可以在线torch.hub.load加载

#离线模型,模型定义具体省略
my_model =  models.init_models(myargs)#省略
model_pkl = "******/dino_finetune.pkl"#加载自己训练好的模型
my_model.load_state_dict(torch.load(model_pkl))
my_model.eval()
##在线模型加载
#my_model = torch.hub.load('facebookresearch/deit:main','deit_tiny_patch16_224', #pretrained=True)
#my_model.eval()
# 判断是否使用 GPU 加速
use_cuda = torch.cuda.is_available()
if use_cuda:
    my_model = my_model.cuda() #如果是gpu的话加速
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

三、选择目标层来计算grad_cam,由于 ViT 的最后一层只有类别标记对预测类别有影响,所以我们不能选择最后一层。我们可以选择倒数第二层中的任意一个 Transformer 编码器作为目标层。

#首先定义函数对vit输出的3维张量转换为传统卷积处理时的二维张量,gradcam需要。
#(B,H*W,feat_dim)转换到(B,C,H,W),其中H*W是分pathc数。具体参数根据自己模型情况
#我的输入为224*224,pathsize为(16*16),那么我的(H,W)就是(224/16,224/16),即14*14
def reshape_transform(tensor, height=14, width=14):
    # 去掉cls token
    result = tensor[:, 1:, :].reshape(tensor.size(0),
    height, width, tensor.size(2))
    # 将通道维度放到第一个位置
    result = result.transpose(2, 3).transpose(1, 2)
    return result
    
# 创建 GradCAM 对象
cam = GradCAM(model=model,
            target_layers=[model.blocks[-1].norm1],
            # 这里的target_layer要看模型情况,调试时自己打印下model吧
            # 比如还有可能是:target_layers = [model.blocks[-1].ffn.norm]
            # 或者target_layers = [model.blocks[-1].ffn.norm]
            use_cuda=use_cuda,
            reshape_transform=reshape_transform)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

四、读入图片,预处理后送入网络。调用 cam 对象的 forward 方法,传入输入张量和预测类别(如果不指定,则默认为最高概率的类别),得到 Grad-CAM 的输出

# 读取输入图像
image_path = "xxx.jpg"
rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
rgb_img = cv2.resize(rgb_img, (224, 224))

# 预处理图像
input_tensor = preprocess_image(rgb_img,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

# 看情况将图像转换为批量形式
# input_tensor = input_tensor.unsqueeze(0)
if use_cuda:
    input_tensor = input_tensor.cuda()

# 计算 grad-cam
target_category = None # 可以指定一个类别,或者使用 None 表示最高概率的类别
grayscale_cam = cam(input_tensor=input_tensor, targets=target_category)
grayscale_cam = grayscale_cam[0, :]

# 将 grad-cam 的输出叠加到原始图像上
#visualization = show_cam_on_image(rgb_img, grayscale_cam),借鉴的代码rgb格式不对,换成下面
visualization = show_cam_on_image(rgb_img.astype(dtype=np.float32)/255,grayscale_cam)

# 保存可视化结果
cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR, visualization)##注意自己图像格式,吐过本身就是BGR,就不用这行

cv2.imwrite('cam.jpg', visualization)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

注意

报错一:如果grad-cam版本过高,会报错Grad-cam报错AttributeError: ‘GradCAM‘ object has no attribute ‘activations_and_grads‘:所以装1.4版本。

报错二、gradcam报错2 AttributeError: ‘list’ object has no attribute ‘cpu’,是因为grad_cam在通过分类层结果确认梯度贡献,而我的代码是做识别方向的,return的feat特征,而不是cls分类。改变网络最后return为分类层输出即可。同时如果用了circleproduct还要注意测试时是否传了label进入网络

报错三、RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn。是因为我为了省事,在工程文件的测试代码中加grad_cam可视化,然后with torch.no_grad()忘记注释了,导致grad_cam计算时没有梯度而报错。注释即可

注意是否红色和蓝色区域互换了,红色应该是注意力地方,如果反过来了,那就是图片rgb和bgr格式问题了

 cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB, visualization)
  • 1

#代码借鉴于:https://zhuanlan.zhihu.com/p/640450435

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Guff_9hys/article/detail/764382
推荐阅读
相关标签
  

闽ICP备14008679号