赞
踩
CNN中的特征可视化大体可分为两类:
第一类方法只显示了在深层特征中保留了哪些信息,而没有突出显示这些信息的相对重要性。第二类方法则具有一定的解释性,例如在分类任务中,通过CAM能够解释模型究竟是通过重点学习输入图像中的哪些信息来判断类别的。
Network in Network中提出了用全局平均池化(GAP)替代全连接层以加强特征映射与类别之间的联系,更具可解释性。受该思想启发,CAM可视化技术应运而出。生成CAM的流程如下图所示(论文原图):
可以看出,生成CAM的步骤非常简单,但是对网络结构有要求(网络末端为GAP+FC这样的结构,并且FC只有一层,用于输出类别概率)。假设分类任务采用的是VGG网络,此时生成CAM的步骤为:
不难发现,若网络结构不符合要求,按照上述方法计算CAM需要修改网络结构和重新训练。针对该问题,后续研究中提出了Gard-CAM。
由上述CAM的计算方法可知,生成CAM的关键是获取特征图的权重。基于对原始CAM的改进,Grad-CAM通过求网络输出的类别置信度对特征图的偏导来获取权重,适用于任意网络,并且能够可视化任意层的类激活图(通常选择最后一个卷积层,因为其包含了丰富的高级语义和空间信息)。
import numpy as np import torch import cv2 import matplotlib.pyplot as plt import torchvision.models as models from torchvision.transforms import Compose, Normalize, ToTensor class GradCAM(): ''' Grad-cam: Visual explanations from deep networks via gradient-based localization Selvaraju R R, Cogswell M, Das A, et al. https://openaccess.thecvf.com/content_iccv_2017/html/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.html ''' def __init__(self, model, target_layers, use_cuda=True): super(GradCAM).__init__() self.use_cuda = use_cuda self.model = model self.target_layers = target_layers self.target_layers.register_forward_hook(self.forward_hook) self.target_layers.register_full_backward_hook(self.backward_hook) self.activations = [] self.grads = [] def forward_hook(self, module, input, output): self.activations.append(output[0]) def backward_hook(self, module, grad_input, grad_output): self.grads.append(grad_output[0].detach()) def calculate_cam(self, model_input): if self.use_cuda: device = torch.device('cuda') self.model.to(device) # Module.to() is in-place method model_input = model_input.to(device) # Tensor.to() is not a in-place method self.model.eval() # forward y_hat = self.model(model_input) max_class = np.argmax(y_hat.cpu().data.numpy(), axis=1) # backward model.zero_grad() y_c = y_hat[0, max_class] y_c.backward() # get activations and gradients activations = self.activations[0].cpu().data.numpy().squeeze() grads = self.grads[0].cpu().data.numpy().squeeze() # calculate weights weights = np.mean(grads.reshape(grads.shape[0], -1), axis=1) weights = weights.reshape(-1, 1, 1) cam = (weights * activations).sum(axis=0) cam = np.maximum(cam, 0) # ReLU cam = cam / cam.max() return cam @staticmethod def show_cam_on_image(image, cam): # image: [H,W,C] h, w = image.shape[:2] cam = cv2.resize(cam, (h,w)) cam = cam / cam.max() heatmap = cv2.applyColorMap((255*cam).astype(np.uint8), cv2.COLORMAP_JET) # [H,W,C] heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) image = image / image.max() heatmap = heatmap / heatmap.max() result = 0.4*heatmap + 0.6*image result = result / result.max() plt.figure() plt.imshow((result*255).astype(np.uint8)) plt.colorbar(shrink=0.8) plt.tight_layout() plt.show() @staticmethod def preprocess_image(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): preprocessing = Compose([ ToTensor(), Normalize(mean=mean, std=std) ]) return preprocessing(img.copy()).unsqueeze(0) if __name__ == '__main__': image = cv2.imread('both.png') # (224,224,3) input_tensor = GradCAM.preprocess_image(image) model = models.resnet18(pretrained=True) grad_cam = GradCAM(model, model.layer4[-1], 224) cam = grad_cam.calculate_cam(input_tensor) GradCAM.show_cam_on_image(image, cam)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。