当前位置:   article > 正文

yolov8的predict使用方法,更改predict.py的输出结果,输出label的真实坐标,保存图片和txt文档,图片中没有异物生成空的txt文档_yolov8 predict

yolov8 predict

Yolov8的predict的使用方法:

新建predict.py如下:

  1. from ultralytics import YOLO
  2. model = YOLO("你训练好的模型.pt")
  3. model.predict(source="datasets/images/val",save=True,save_conf=True,save_txt=True,name='output')
  4. #source后为要预测的图片数据集的的路径
  5. #save=True为保存预测结果
  6. #save_conf=True为保存坐标信息
  7. #save_txt=True为保存txt结果,但是yolov8本身当图片中预测不到异物时,不产生txt文件

默认predict输出到txt的坐标为归一化的xywh格式(即xywhn:中心点坐标和预测框的宽高,均按图片尺寸归一化)。若要改为图片上的真实像素坐标(xyxy,即预测框左上角和右下角的坐标),方法如下:在ultralytics\engine\results.py中,将def save_txt()置换为如下代码:

  1. def save_txt(self, txt_file, save_conf=False):
  2. """
  3. Save predictions into txt file.
  4. Args:
  5. txt_file (str): txt file path.
  6. save_conf (bool): save confidence score or not.
  7. """
  8. boxes = self.boxes
  9. masks = self.masks
  10. probs = self.probs
  11. kpts = self.keypoints
  12. texts = []
  13. if probs is not None:
  14. # Classify
  15. [texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
  16. elif boxes:
  17. # Detect/segment/pose
  18. for j, d in enumerate(boxes):
  19. c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
  20. #line = (c, *d.xywhn.view(-1))
  21. line = ( c,conf, *d.xyxy.view(-1)) #重点在这里,还可以通过这里改变txt中信息#的顺序
  22. if masks:
  23. seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)
  24. line = (c, *seg)
  25. if kpts is not None:
  26. kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
  27. line += (*kpt.reshape(-1).tolist(), )
  28. line += (conf, ) * save_conf + (() if id is None else (id, ))
  29. line = line[:-1]
  30. texts.append('%s %.6f %d %d %d %d' % (line[0], float(line[1]), int(line[2]), int(line[3]), int(line[4]), int(line[5])))
  31. if texts:
  32. Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory
  33. with open(txt_file, 'a') as f:
  34. f.writelines(text + '\n' for text in texts)

要实现图片中没有检测到异物时也生成空的txt文档,采用如下方法:

在ultralytics\engine\predictor.py中,找到def stream_inference(),将其置换为如下代码:

  1. def stream_inference(self, source=None, model=None, *args, **kwargs):
  2. """Streams real-time inference on camera feed and saves results to file."""
  3. if self.args.verbose:
  4. LOGGER.info('')
  5. # Setup model
  6. if not self.model:
  7. self.setup_model(model)
  8. # Setup source every time predict is called
  9. self.setup_source(source if source is not None else self.args.source)
  10. # Check if save_dir/ label file exists
  11. if self.args.save or self.args.save_txt:
  12. (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
  13. # Warmup model
  14. if not self.done_warmup:
  15. self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
  16. self.done_warmup = True
  17. self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
  18. self.run_callbacks('on_predict_start')
  19. for batch in self.dataset:
  20. self.run_callbacks('on_predict_batch_start')
  21. self.batch = batch
  22. path, im0s, vid_cap, s = batch
  23. # Preprocess
  24. with profilers[0]:
  25. im = self.preprocess(im0s)
  26. # Inference
  27. with profilers[1]:
  28. preds = self.inference(im, *args, **kwargs)
  29. # Postprocess
  30. with profilers[2]:
  31. self.results = self.postprocess(preds, im, im0s)
  32. self.run_callbacks('on_predict_postprocess_end')
  33. # Visualize, save, write results
  34. n = len(im0s)
  35. for i in range(n):
  36. self.seen += 1
  37. self.results[i].speed = {
  38. 'preprocess': profilers[0].dt * 1E3 / n,
  39. 'inference': profilers[1].dt * 1E3 / n,
  40. 'postprocess': profilers[2].dt * 1E3 / n}
  41. p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
  42. p = Path(p)
  43. if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
  44. s += self.write_results(i, self.results, (p, im, im0))
  45. if self.args.save or self.args.save_txt:
  46. self.results[i].save_dir = self.save_dir.__str__()
  47. if self.args.show and self.plotted_img is not None:
  48. self.show(p)
  49. if self.args.save and self.plotted_img is not None:
  50. self.save_preds(vid_cap, i, str(self.save_dir / p.name))
  51. self.run_callbacks('on_predict_batch_end')
  52. yield from self.results
  53. # Print time (inference-only)
  54. if self.args.verbose:
  55. LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
  56. # Release assets
  57. if isinstance(self.vid_writer[-1], cv2.VideoWriter):
  58. self.vid_writer[-1].release() # release final video writer
  59. # Print results
  60. if self.args.verbose and self.seen:
  61. t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
  62. LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
  63. f'{(1, 3, *im.shape[2:])}' % t)
  64. if self.args.save or self.args.save_txt or self.args.save_crop:
  65. nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
  66. s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
  67. m = f"\n{nl} image{'s' * (nl > 1)} saved to {self.save_dir / 'images'}" if self.args.save_txt else ''
  68. LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
  69. # 原始文件夹路径和目标文件夹路径
  70. source_folder = self.save_dir
  71. target_folder = os.path.join(source_folder, 'images')
  72. target_label_folder = os.path.join(source_folder, 'labels')
  73. # 创建目标文件夹
  74. os.makedirs(target_folder, exist_ok=True)
  75. # 获取原始文件夹中的所有文件
  76. files = os.listdir(source_folder)
  77. # 遍历所有文件
  78. for file in files:
  79. # 构建文件的绝对路径
  80. file_path = os.path.join(source_folder, file)
  81. # 检查文件是否是图片文件
  82. if os.path.isfile(file_path) and file.lower().endswith(('.jpg', '.jpeg', '.png')):
  83. # 构建目标文件的路径
  84. target_file_path = os.path.join(target_folder, file)
  85. # 将图片文件复制到目标文件夹
  86. shutil.copy(file_path, target_file_path)
  87. # 删除原始文件夹中的图片文件
  88. os.remove(file_path)
  89. image_folder = target_folder
  90. txt_folder = target_label_folder
  91. image_files = [file for file in os.listdir(image_folder) if file.lower().endswith(('.jpg', '.jpeg', '.png'))]
  92. # 遍历图片文件夹中的图片文件
  93. for image_file in image_files:
  94. # 构建图片和txt文件的路径
  95. image_file_path = os.path.join(image_folder, image_file)
  96. txt_file_path = os.path.join(txt_folder, os.path.splitext(image_file)[0] + '.txt')
  97. # 检查txt文件是否已存在
  98. if os.path.isfile(txt_file_path):
  99. #print(f"txt文件'{txt_file_path}'已存在,不执行操作。")
  100. pass
  101. else:
  102. # 创建空白的txt文件
  103. with open(txt_file_path, 'w') as txt_file:
  104. pass
  105. self.run_callbacks('on_predict_end')

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号