当前位置:   article > 正文

热力图批量处理(`yolov5_heatmap` 类)

`class yolov5_heatmap: def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio)`

本帖主要用于记录学习「魔鬼面具」作者的热力图代码(原文:《YOLOV8-gradcam 热力图可视化 即插即用 不需要对源码做任何修改!》,魔鬼面具的 CSDN 博客),

并将其改写为支持批量处理。代码仅供参考,如有不足还请海涵。

  1. import warnings
  2. warnings.filterwarnings('ignore')
  3. warnings.simplefilter('ignore')
  4. import torch, yaml, cv2, os, shutil
  5. import numpy as np
  6. np.random.seed(0)
  7. import matplotlib.pyplot as plt
  8. from tqdm import trange
  9. from PIL import Image
  10. from models.yolo import Model
  11. from utils.general import intersect_dicts
  12. from utils.augmentations import letterbox
  13. from utils.general import xywh2xyxy
  14. from pytorch_grad_cam import GradCAMPlusPlus, GradCAM, XGradCAM
  15. from pytorch_grad_cam.utils.image import show_cam_on_image
  16. from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
  17. class yolov5_heatmap:
  18. def __init__(self, weight, cfg, device, method, layer, backward_type, conf_threshold, ratio):
  19. device = torch.device(device)
  20. ckpt = torch.load(weight)
  21. model_names = ckpt['model'].names
  22. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  23. model = Model(cfg, ch=3, nc=len(model_names)).to(device)
  24. csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor']) # intersect
  25. model.load_state_dict(csd, strict=False) # load
  26. model.eval()
  27. print(f'Transferred {len(csd)}/{len(model.state_dict())} items')
  28. target_layers = [eval(layer)]
  29. method = eval(method)
  30. colors = np.random.uniform(0, 255, size=(len(model_names), 3)).astype(np.int)
  31. self.__dict__.update(locals())
  32. def post_process(self, result):
  33. logits_ = result[..., 4:]
  34. boxes_ = result[..., :4]
  35. sorted, indices = torch.sort(logits_[..., 0], descending=True)
  36. return logits_[0][indices[0]], xywh2xyxy(boxes_[0][indices[0]]).cpu().detach().numpy()
  37. """def draw_detections(self, box, color, name, img):
  38. xmin, ymin, xmax, ymax = list(map(int, list(box)))
  39. cv2.rectangle(img, (xmin, ymin), (xmax, ymax), tuple(int(x) for x in color), 2)
  40. cv2.putText(img, str(name), (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.8, tuple(int(x) for x in color), 2, lineType=cv2.LINE_AA)
  41. return img"""
  42. #此处控制是否显示预测框
  43. def __call__(self, img_path, save_path):
  44. # remove dir if exist
  45. if os.path.exists(save_path):
  46. shutil.rmtree(save_path)
  47. # make dir if not exist
  48. os.makedirs(save_path, exist_ok=True)
  49. # img process
  50. img = cv2.imread(img_path)
  51. img = letterbox(img)[0]
  52. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  53. img = np.float32(img) / 255.0
  54. tensor = torch.from_numpy(np.transpose(img, axes=[2, 0, 1])).unsqueeze(0).to(self.device)
  55. # init ActivationsAndGradients
  56. grads = ActivationsAndGradients(self.model, self.target_layers, reshape_transform=None)
  57. # get ActivationsAndResult
  58. result = grads(tensor)
  59. activations = grads.activations[0].cpu().detach().numpy()
  60. # postprocess to yolo output
  61. post_result, post_boxes = self.post_process(result[0])
  62. for i in trange(int(post_result.size(0) * self.ratio)):
  63. if post_result[i][0] < self.conf_threshold:
  64. break
  65. self.model.zero_grad()
  66. if self.backward_type == 'conf':
  67. post_result[i, 0].backward(retain_graph=True)
  68. else:
  69. # get max probability for this prediction
  70. score = post_result[i, 1:].max()
  71. score.backward(retain_graph=True)
  72. # process heatmap
  73. gradients = grads.gradients[0]
  74. b, k, u, v = gradients.size()
  75. weights = self.method.get_cam_weights(self.method, None, None, None, activations, gradients.detach().numpy())
  76. weights = weights.reshape((b, k, 1, 1))
  77. saliency_map = np.sum(weights * activations, axis=1)
  78. saliency_map = np.squeeze(np.maximum(saliency_map, 0))
  79. saliency_map = cv2.resize(saliency_map, (tensor.size(3), tensor.size(2)))
  80. saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
  81. if (saliency_map_max - saliency_map_min) == 0:
  82. continue
  83. saliency_map = (saliency_map - saliency_map_min) / (saliency_map_max - saliency_map_min)
  84. # add heatmap and box to image
  85. cam_image = show_cam_on_image(img.copy(), saliency_map, use_rgb=True)
  86. """cam_image = self.draw_detections(post_boxes[i], self.colors[int(post_result[i, 1:].argmax())], f'{self.model_names[int(post_result[i, 1:].argmax())]} {post_result[i][0]:.2f}', cam_image)"""
  87. cam_image = Image.fromarray(cam_image)
  88. cam_image.save(f'{save_path}/{i}.png')
  89. def get_params():
  90. params = {
  91. 'weight': r'权重绝对地址',
  92. 'cfg': r'模型绝对地址',
  93. 'device': 'cuda:0',
  94. 'method': 'GradCAM', # GradCAMPlusPlus, GradCAM, XGradCAM
  95. 'layer': 'model.model[-2]',
  96. 'backward_type': 'class', # class or conf
  97. 'conf_threshold': 0.6, # 0.6
  98. 'ratio': 0.02 # 0.02-0.1
  99. }
  100. return params
  101. ################# 要使用热力图需要将YOLO.py文件中所有的inplace设置为False,默认是True #############################################
  102. if __name__ == '__main__':
  103. model = yolov5_heatmap(**get_params())
  104. #model(r'图片绝对地址', 'result')#单张处理
  105. path=r"图片绝对地址"
  106. print(path)
  107. path1 = os.listdir(path)
  108. for i1 in path1:
  109. i2="result/"+str(i1)
  110. i1=os.path.join(path,i1)
  111. model(i1,i2)#多张处理

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/135190
推荐阅读
相关标签
  

闽ICP备14008679号