  1. import warnings
  2. warnings.filterwarnings('ignore')
  3. warnings.simplefilter('ignore')
  4. import torch, yaml, cv2, os, shutil
  5. import numpy as np
  6. np.random.seed(0)
  7. import matplotlib.pyplot as plt
  8. from tqdm import trange
  9. from PIL import Image
  10. from models.yolo import Model
  11. from utils.general import intersect_dicts
  12. from utils.augmentations import letterbox
  13. from utils.general import xywh2xyxy
  14. from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
  15. from pytorch_grad_cam.utils.image import show_cam_on_image
  16. from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
  17. class yolov5_heatmap:
  18. def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):
  19. device = torch.device(device)
  20. ckpt = torch.load(weight)
  21. model_names = ckpt['model'].names
  22. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  23. model = Model(cfg, ch=3, nc=len(model_names)).to(device)
  24. csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor']) # intersect
  25. model.load_state_dict(csd, strict=False) # load
  26. model.eval()
  27. print(f'Transferred {len(csd)}/{len(model.state_dict())} items')
  28. target_layers = [eval(layer)]
  29. method = eval(method)
  30. colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int)
  31. self.__dict__.update(locals())
  32. def post_process(self, result):
  33. logits_ = result[..., 4:]
  34. boxes_ = result[..., :4]
  35. sorted, indices = torch.sort(logits_[..., 0], descending=True)
  36. return logits_[0][indices[0]], xywh2xyxy(boxes_[0][indices[0]]).cpu().detach().numpy()
  37. def draw_detections(self, box, color, name, img):
  38. xmin, ymin, xmax, ymax = list(map(int, list(box)))
  39. cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)
  40. cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)
  41. return img
  42. def __call__(self, img_path, save_path):
  43. # remove dir if exist
  44. if os.path.exists(save_path):
  45. shutil.rmtree(save_path)
  46. # make dir if not exist
  47. os.makedirs(save_path, exist_ok=True)
  48. # img process
  49. img = cv2.imread(img_path)
  50. img = letterbox(img)[0]
  51. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  52. img = np.float32(img) / 255.0
  53. tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)
  54. # init ActivationsAndGradients
  55. grads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)
  56. # get ActivationsAndResult
  57. result = grads(tensor)
  58. activations = grads.activations[0].cpu().detach().numpy()
  59. # postprocess to yolo output
  60. post_result, post_boxes = self.post_process(result[0])
  61. for i in trange(int(post_result.size(0) * self.ratio)):
  62. if post_result[i][0] < self.conf_threshold:
  63. break
  64. self.model.zero_grad()
  65. if self.backward_type == 'conf':
  66. post_result[i, 0].backward(retain_graph=True)
  67. else:
  68. # get max probability for this prediction
  69. score = post_result[i, 1:].max()
  70. score.backward(retain_graph=True)
  71. # process heatmap
  72. gradients = grads.gradients[0]
  73. b, k, u, v = gradients.size()
  74. weights = self.method.get_cam_weights(self.method, None, None, None, activations, gradients.detach().numpy())
  75. weights = weights.reshape((b, k, 1, 1))
  76. saliency_map = np.sum(weights * activations, axis=1)
  77. saliency_map = np.squeeze(np.maximum(saliency_map, 0))
  78. saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))
  79. saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
  80. if (saliency_map_max - saliency_map_min) == 0:
  81. continue
  82. saliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)
  83. # add heatmap and box to image
  84. cam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)
  85. cam_image = self.draw_detections(post_boxes[i], self.colors[int(post_result[i, 1:].argmax())], f'{self.model_names[int(post_result[i, 1:].argmax())]} {post_result[i][0]:.2f}', cam_image)
  86. cam_image = Image.fromarray(cam_image)
  87. cam_image.save(f'{save_path}/{i}.png')
  88. def get_params():
  89. params = {
  90. 'weight': 'runs/train/exp/weights/best.pt',
  91. 'cfg': 'models/yolov5m.yaml',
  92. 'device': 'cuda:0',
  93. 'method': 'XGradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM
  94. 'layer': 'model.model[-2]',
  95. 'backward_type': 'class', # class or conf
  96. 'conf_threshold': 0.6, # 0.6
  97. 'ratio': 0.02 # 0.02-0.1
  98. }
  99. return params
  100. if __name__ == '__main__':
  101. model = yolov5_heatmap(**get_params())
  102. model(r'dataset\images\test\aircraft_1064.jpg', 'result')
  1. 需要安装 pytorch_grad_cam 库:可以直接执行 pip install grad-cam,或者到 GitHub 的 jacobgil/pytorch-grad-cam 仓库下载源码,只需要把其中的 pytorch_grad_cam 文件夹放入项目中即可。

  1. get_params中的参数:

  1. weight:模型权重文件

  1. cfg:模型结构配置文件(yaml),例如 models/yolov5m.yaml

  1. device:选择使用GPU还是CPU

  1. method:选择 Grad-CAM 方法,可选 GradCAM、GradCAMPlusPlus、XGradCAM,不同方法的效果略有差异,可以都尝试一下

  1. layer:需要可视化的网络层,以 Python 表达式字符串给出,例如 'model.model[-2]'

  1. backward_type:反向传播的起点,取 'conf' 时从预测的置信度得分反向传播,取其他值(如 'class')时从该预测的最大类别得分反向传播

  1. conf_threshold:置信度阈值,置信度低于该值的预测不会生成热力图

  1. ratio:将预测结果按置信度从高到低排序后,只对前 ratio 比例的结果生成热力图(例如 0.02 表示前 2%),其余舍弃




将model.eval()改为 model.fuse().eval()

2.inplace 出错




