
yolov7 mask training notes


Contents

Training process:
Do not stop training too early:
YAML file parsing
About pretraining
dataloader
coco128-seg.yaml
segments format:
Parsing JSON-format labels:
Counting the number of points:
Splitting merged COCO labels into separate per-image files:
Converting COCO labels to separate labels with linear interpolation:
Exporting ONNX
ONNX structure:
yolov7 mask Python training, TensorRT inference framework
Comparing original images with results in Python:


Training process:

https://github.com/chelsea456/yolov7_mask/tree/main/yolov7_mask

Do not stop training too early:

parser.add_argument('--patience', type=int, default=0, help='EarlyStopping patience (epochs without improvement)')
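With the default of 0, early stopping is effectively disabled, since the YOLOv5-style EarlyStopping class treats a patience of 0 as infinite. A minimal sketch of how the patience check typically works (illustrative only, not the exact upstream code):

# Minimal sketch of a YOLOv5-style EarlyStopping check (illustrative, not the exact upstream implementation).
class EarlyStopping:
    def __init__(self, patience=30):
        self.best_fitness = 0.0
        self.best_epoch = 0
        self.patience = patience or float('inf')  # patience=0 -> infinite patience, never stop early

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:  # >= keeps training through flat plateaus
            self.best_epoch = epoch
            self.best_fitness = fitness
        return (epoch - self.best_epoch) >= self.patience  # True -> stop training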

YAML file parsing

3. backbone (backbone network)
# yolov7 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [32, 3, 1]],  # 0
  
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2      
   [-1, 1, Conv, [64, 3, 1]],
   
   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4  
   [-1, 1, Conv, [64, 1, 1]],
   [-2, 1, Conv, [64, 1, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]],  # 11
         
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 16-P3/8  
   [-1, 1, Conv, [128, 1, 1]],
   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1]],  # 24
         
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 29-P4/16  
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [1024, 1, 1]],  # 37
         
   [-1, 1, MP, []],
   [-1, 1, Conv, [512, 1, 1]],
   [-3, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 42-P5/32  
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [1024, 1, 1]],  # 50
  ]
from indicates where the layer's input comes from: -1 means the previous layer, -2 the layer two back, 3 means layer 3 (counting from 0), [-1, 4] means the outputs of the previous layer and of layer 4, and so on.
number is how many times the module is repeated. For modules such as C3 and BottleneckCSP it is the number of stacked sub-modules (see the source code for details); the final count is also multiplied by the depth_multiple factor.
module is the module type of the layer. Conv is convolution + BN + activation. All modules are defined in models/common.py.
args are the arguments passed to the module. For Conv, [128, 3, 2] means 128 output channels, kernel size 3, stride 2; the final number of output channels is also multiplied by width_multiple. For most other modules the first argument is likewise the output channel count; see the definitions in models/common.py.
 
Source: https://blog.csdn.net/weixin_43397302/article/details/126708227
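To make the [from, number, module, args] table concrete, the sketch below shows how such rows are turned into layers. This is a simplified illustration, not the real parser (parse_model() in models/yolo.py), which additionally applies width_multiple to the channel counts and special-cases modules such as Concat, MP and the detection head; the stand-in Conv class here only mimics the Conv in models/common.py.

import torch
import torch.nn as nn

class Conv(nn.Module):  # stand-in for models/common.py Conv (Conv2d + BN + SiLU)
    def __init__(self, c1, c2, k=1, s=1):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, k // 2, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()
    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

rows = [[-1, 1, 'Conv', [32, 3, 1]],   # from=-1 (previous layer), 1 repeat, 32 output channels, k=3, s=1
        [-1, 1, 'Conv', [64, 3, 2]]]   # stride 2 halves the resolution (P1/2)
modules = {'Conv': Conv}

layers, ch = [], [3]                    # ch[i] tracks output channels per layer; the input is a 3-channel image
for f, n, m, args in rows:
    c1 = ch[f]                          # input channels come from the 'from' layer
    layers.append(modules[m](c1, *args))  # for Conv, args[0] is the output channel count
    ch.append(args[0])

x = torch.zeros(1, 3, 64, 64)
for layer in layers:
    x = layer(x)
print(x.shape)  # torch.Size([1, 64, 32, 32])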

About pretraining

Three pretrained weights are provided:

yolov5s-seg.pt

yolov7-seg.pt

yolov7x-seg.pt

python segment/train.py --data coco.yaml --batch 16 --weights '' --cfg yolov7-seg.yaml --epochs 300 --name yolov7-seg --img 640 --hyp hyp.scratch-high.yaml
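To fine-tune from one of the provided weights on a custom dataset instead of training from scratch, a command of the following shape should work (data_y.yaml and the run name are placeholders for your own files):

python segment/train.py --data data_y.yaml --batch 16 --weights yolov7-seg.pt --cfg yolov7-seg.yaml --epochs 300 --name yolov7-seg-custom --img 640 --hyp hyp.scratch-high.yaml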

dataloader

labels, shapes, self.segments = zip(*cache.values())
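Here cache is a dict mapping each image path to [labels, shape, segments], so zip(*cache.values()) transposes it into three parallel tuples. A small illustration of the layout (all values made up):

import numpy as np

cache = {
    'images/train/0001.jpg': [
        np.array([[0, 0.5, 0.5, 0.2, 0.3]], dtype=np.float32),  # labels: one row per object, (cls, x, y, w, h) normalized
        (640, 480),                                              # image shape as (width, height)
        [np.array([[0.40, 0.35], [0.60, 0.35], [0.60, 0.65], [0.40, 0.65]], dtype=np.float32)],  # normalized polygons
    ],
}
labels, shapes, segments = zip(*cache.values())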

image_detect/coco128-seg.yaml at master · HoyoenKim/image_detect · GitHub

coco128-seg.yaml

download: https://ultralytics.com/assets/coco128-seg.zip

# YOLOv5 by Ultralytics, GPL-3.0 license
# COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
# Example usage: python train.py --data coco128.yaml
# parent
# ├── yolov5
# └── datasets
#     └── coco128-seg  ← downloads here (7 MB)

# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ../datasets/coco128-seg  # dataset root dir
train: images/train2017  # train images (relative to 'path') 128 images
val: images/train2017  # val images (relative to 'path') 128 images
test:  # test images (optional)

# Classes
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
  4: airplane
  5: bus
  6: train
  7: truck
  8: boat
  9: traffic light
  10: fire hydrant
  11: stop sign
  12: parking meter
  13: bench
  14: bird
  15: cat
  16: dog
  17: horse
  18: sheep
  19: cow
  20: elephant
  21: bear
  22: zebra
  23: giraffe
  24: backpack
  25: umbrella
  26: handbag
  27: tie
  28: suitcase
  29: frisbee
  30: skis
  31: snowboard
  32: sports ball
  33: kite
  34: baseball bat
  35: baseball glove
  36: skateboard
  37: surfboard
  38: tennis racket
  39: bottle
  40: wine glass
  41: cup
  42: fork
  43: knife
  44: spoon
  45: bowl
  46: banana
  47: apple
  48: sandwich
  49: orange
  50: broccoli
  51: carrot
  52: hot dog
  53: pizza
  54: donut
  55: cake
  56: chair
  57: couch
  58: potted plant
  59: bed
  60: dining table
  61: toilet
  62: tv
  63: laptop
  64: mouse
  65: remote
  66: keyboard
  67: cell phone
  68: microwave
  69: oven
  70: toaster
  71: sink
  72: refrigerator
  73: book
  74: clock
  75: vase
  76: scissors
  77: teddy bear
  78: hair drier
  79: toothbrush

# Download script/URL (optional)
download: https://ultralytics.com/assets/coco128-seg.zip
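For reference, the label files of coco128-seg (labels/train2017/*.txt) use the YOLO segmentation format: one object per line, the class index followed by the polygon vertices, all normalized to 0-1. A made-up example line (not taken from the real dataset):

# class x1 y1 x2 y2 x3 y3 ...
45 0.479492 0.688771 0.545898 0.633590 0.607422 0.611953 0.637695 0.654297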

segments format:

segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
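xyn2xy converts normalized polygon coordinates into pixel coordinates of the (possibly letterbox-padded) image; in the YOLOv5/yolov7-seg utils it is essentially the following:

import numpy as np
import torch

def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n, 2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = w * x[..., 0] + padw  # x coordinates plus horizontal padding
    y[..., 1] = h * x[..., 1] + padh  # y coordinates plus vertical padding
    return y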

segments2boxes

def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh
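A quick check of what it returns (hypothetical example; run it inside the repo so that utils.general is importable):

import numpy as np
from utils.general import segments2boxes

square = [np.array([[10., 10.], [50., 10.], [50., 40.], [10., 40.]])]  # one 40x30 rectangle as a polygon
print(segments2boxes(square))  # -> [[30. 25. 40. 30.]]  i.e. (x_center, y_center, width, height)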

Parsing JSON-format labels:

def cache_labels(self, path=Path('./labels.cache'), prefix=''):
    # Cache dataset labels, check images and read shapes
    x = {}  # dict
    nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
    desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
    if self.data_type == "json":
        pbar = tqdm(zip(self.im_files, self.label_files), desc='Scanning images', total=len(self.im_files))
        for i, (im_file, lb_file) in enumerate(pbar):
            try:
                annotations = np.zeros((0, 5), dtype=np.float32)
                shape = np.zeros((2,), dtype=np.int32)
                label = []
                segments = []
                if lb_file.endswith(".json"):
                    json_file = json.load(open(lb_file, "r", encoding="utf-8"))
                    imageHeight = json_file['imageHeight']
                    imageWidth = json_file['imageWidth']
                    shape[0] = imageWidth
                    shape[1] = imageHeight
                    for multi in json_file["shapes"]:
                        points = np.array(multi["points"])
                        xmin = (min(points[:, 0]) if min(points[:, 0]) > 0 else 0) / imageWidth
                        xmax = (max(points[:, 0]) if max(points[:, 0]) > 0 else 0) / imageWidth
                        ymin = (min(points[:, 1]) if min(points[:, 1]) > 0 else 0) / imageHeight
                        ymax = (max(points[:, 1]) if max(points[:, 1]) > 0 else 0) / imageHeight
                        label = multi["label"]
                        if xmax > xmin and ymax > ymin:
                            annotation = np.zeros((1, 5), dtype=np.float32)
                            annotation[0, 1] = (xmin + xmax) / 2
                            annotation[0, 2] = (ymin + ymax) / 2
                            annotation[0, 3] = xmax - xmin
                            annotation[0, 4] = ymax - ymin
                            # cls
                            annotation[0, 0] = self.hyp['names'].index(label.lower().strip())
                            # annotation[0, 0] = _class_to_ind[label.lower().strip()]
                            # annotations = np.append(annotations, annotation, axis=0)
                            annotations = np.row_stack((annotations, annotation))
                            points = points.astype(np.float32)
                            points[:, 0] = points[:, 0] / imageWidth
                            points[:, 1] = points[:, 1] / imageHeight
                            segments.append(points)
                if len(annotations) < 1:
                    nm += 1
                    print("json no obj------------")
                    annotations = np.zeros((0, 5), dtype=np.float32)
                else:
                    nf += 1
                x[im_file] = [annotations, shape, segments]
            except Exception as e:
                nc += 1
                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
        pbar.close()
        if nf == 0:
            print(f'{prefix}WARNING: No labels found in {path}.')
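This parser expects labelme-style JSON files next to the images, reading imageWidth, imageHeight and the polygon shapes. A minimal made-up example of the structure it consumes:

{
  "imagePath": "0001.jpg",
  "imageHeight": 480,
  "imageWidth": 640,
  "imageData": null,
  "flags": {},
  "shapes": [
    {
      "label": "quexian",
      "shape_type": "polygon",
      "points": [[120.0, 80.0], [300.0, 85.0], [310.0, 200.0], [115.0, 190.0]]
    }
  ]
}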

Counting the number of points:

import glob
import json
import numpy as np

if __name__ == '__main__':
    label_dir = r'F:\project\detect\yolov7\yolov7_mask\Y\train'
    lb_files = glob.glob(label_dir + '/*.json', recursive=True)  # f = list(p.rglob('*.*'))  # pathlib
    annotations = np.zeros((0, 5), dtype=np.float32)
    shape = np.zeros((2,), dtype=np.int32)
    label = []
    for lb_file in lb_files:
        if lb_file.endswith(".json"):
            json_file = json.load(open(lb_file, "r", encoding="utf-8"))
            imageHeight = json_file['imageHeight']
            imageWidth = json_file['imageWidth']
            shape[0] = imageWidth
            shape[1] = imageHeight
            for multi in json_file["shapes"]:
                points = np.array(multi["points"])
                xmin = (min(points[:, 0]) if min(points[:, 0]) > 0 else 0) / imageWidth
                xmax = (max(points[:, 0]) if max(points[:, 0]) > 0 else 0) / imageWidth
                ymin = (min(points[:, 1]) if min(points[:, 1]) > 0 else 0) / imageHeight
                ymax = (max(points[:, 1]) if max(points[:, 1]) > 0 else 0) / imageHeight
                label = multi["label"]
                if xmax > xmin and ymax > ymin:
                    print(lb_file, len(points))
                    # points = points.astype(np.float32)

Splitting merged COCO labels into separate per-image files:

import json
import os, cv2
import numpy as np


def visualization_bbox1(json_path, img_path):  # path of the merged COCO json and directory of the images to draw
    with open(json_path, encoding='utf-8') as annos:
        annotation_json = json.load(annos)
    print('num_key is:', len(annotation_json), 'json key is:', annotation_json.keys())  # keys of the json file
    print('json num_images is:', len(annotation_json['images']))  # number of images listed in the json file
    for img_i in range(len(annotation_json['images'])):
        image_name = annotation_json['images'][img_i]['file_name']  # image name
        id = annotation_json['images'][img_i]['id']  # image id
        image_path = os.path.join(img_path, str(image_name).zfill(5))  # build the image path
        image = cv2.imread(image_path, 1)  # read the image in its original format
        num_bbox = 0  # number of bboxes in this image
        coco_train = dict()
        coco_train['flags'] = {}
        coco_train['imagePath'] = os.path.basename(image_path)
        coco_train['shapes'] = []
        coco_train['imageData'] = None
        coco_train['imageHeight'] = image.shape[0]
        coco_train['imageWidth'] = image.shape[1]
        for i in range(len(annotation_json['annotations'][::])):
            if annotation_json['annotations'][i - 1]['image_id'] == id:
                num_bbox = num_bbox + 1
                box_dict = {}
                box_dict["label"] = "quexian"
                box_dict["shape_type"] = "polygon"
                box_dict["points"] = []
                x, y, w, h = annotation_json['annotations'][i - 1]['bbox']  # read the bounding box
                image = cv2.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), (0, 255, 255), 2)
                points = annotation_json['annotations'][i - 1]['segmentation']  # keypoints
                data_len = len(points[0]) // 2
                for index in range(data_len):
                    cv2.circle(image, (int(points[0][index * 2]), int(points[0][index * 2 + 1])), 3, (0, 0, 213), -1)  # x, y, r, color
                    box_dict["points"].append((int(points[0][index * 2]), int(points[0][index * 2 + 1])))
                coco_train['shapes'].append(box_dict)
        # cv2.imwrite(f"gt/{img_i}.jpg", image)
        # cv2.resizeWindow("image_name", 2500, 1250)
        if is_show:
            print(os.path.basename(image_path))
            cv2.namedWindow("image_name", 0)  # create a resizable window
            cv2.imshow("image_name", image)
            cv2.waitKey(0)
        else:
            train_file = image_path[:-4] + ".json"
            with open(train_file, 'w') as write_f:
                write_f.write(json.dumps(coco_train, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    is_show = True
    os.makedirs("gt", exist_ok=True)
    train_json = r'D:\work\lbg\fenge\data/mark.json'
    train_path = r'D:\work\lbg\fenge\data\trainImage'
    visualization_bbox1(train_json, train_path)

Converting COCO labels to separate labels with linear interpolation

import json
import os, cv2
import numpy as np
from scipy.interpolate import interp1d


def visualization_bbox1(json_path, img_path):  # path of the merged COCO json and directory of the images to draw
    with open(json_path, encoding='utf-8') as annos:
        annotation_json = json.load(annos)
    print('num_key is:', len(annotation_json), 'json key is:', annotation_json.keys())  # keys of the json file
    print('json num_images is:', len(annotation_json['images']))  # number of images listed in the json file
    for img_i in range(len(annotation_json['images'])):
        image_name = annotation_json['images'][img_i]['file_name']  # image name
        id = annotation_json['images'][img_i]['id']  # image id
        image_path = os.path.join(img_path, str(image_name).zfill(5))  # build the image path
        image = cv2.imread(image_path, 1)  # read the image in its original format
        num_bbox = 0  # number of bboxes in this image
        coco_train = dict()
        coco_train['flags'] = {}
        coco_train['imagePath'] = os.path.basename(image_path)
        coco_train['shapes'] = []
        coco_train['imageData'] = None
        coco_train['imageHeight'] = image.shape[0]
        coco_train['imageWidth'] = image.shape[1]
        for i in range(len(annotation_json['annotations'][::])):
            if annotation_json['annotations'][i - 1]['image_id'] == id:
                num_bbox = num_bbox + 1
                box_dict = {}
                box_dict["label"] = "quexian"
                box_dict["shape_type"] = "polygon"
                box_dict["points"] = []
                x, y, w, h = annotation_json['annotations'][i - 1]['bbox']  # read the bounding box
                image = cv2.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), (0, 255, 255), 2)
                points = annotation_json['annotations'][i - 1]['segmentation']  # keypoints
                data_len = len(points[0]) // 2
                print('points len is:', data_len)
                data_o = []
                y_row = []
                for index in range(data_len):
                    data_o.append((points[0][index * 2], points[0][index * 2 + 1]))
                    cv2.circle(image, (int(points[0][index * 2]), int(points[0][index * 2 + 1])), 3, (0, 0, 213), -1)  # x, y, r, color
                    # box_dict["points"].append((int(points[0][index * 2]), int(points[0][index * 2 + 1])))
                coco_train['shapes'].append(box_dict)
                data_o = np.asarray(data_o)
                data_o = data_o[np.lexsort(data_o[:, ::-1].T)]  # sort points by x (then y)
                for index, data_x in enumerate(data_o):
                    if index == 0:
                        continue
                    if data_x[0] <= data_o[index - 1][0]:
                        data_x[0] = data_o[index - 1][0] + 0.01  # interp1d requires strictly increasing x
                # f1 = interp1d(data_o[:, 0], data_o[:, 1], kind='cubic')
                f1 = interp1d(data_o[:, 0], data_o[:, 1], kind='linear')
                x_pred = np.linspace(data_o[0][0], data_o[-1][0], num=155)
                y1 = f1(x_pred)
                for index, x_data in enumerate(x_pred):
                    cv2.circle(image, (int(x_data), int(y1[index])), 2, (255, 0, 0), -1)  # x, y, r, color
                    box_dict["points"].append((int(x_data), int(y1[index])))
        # cv2.imwrite(f"gt/{img_i}.jpg", image)
        cv2.namedWindow("image_name", 0)  # create a resizable window
        # cv2.resizeWindow("image_name", 2500, 1250)
        cv2.imshow("image_name", image)
        cv2.waitKey(1)
        train_file = image_path[:-4] + ".json"
        with open(train_file, 'w') as write_f:
            write_f.write(json.dumps(coco_train, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    os.makedirs("gt", exist_ok=True)
    train_json = r'D:\work\lbg\fenge\data/mark.json'
    train_path = r'D:\work\lbg\fenge\data\trainImage'
    visualization_bbox1(train_json, train_path)
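The key step above is resampling the sparse polyline to a fixed number of points with interp1d, after sorting by x and nudging duplicate x values so the interpolator sees strictly increasing inputs. A self-contained sketch of the same idea (input points made up):

import numpy as np
from scipy.interpolate import interp1d

pts = np.array([[10.0, 5.0], [10.0, 8.0], [30.0, 12.0], [20.0, 9.0]])
pts = pts[np.lexsort(pts[:, ::-1].T)]      # sort by x, then y
for i in range(1, len(pts)):
    if pts[i, 0] <= pts[i - 1, 0]:         # interp1d requires strictly increasing x
        pts[i, 0] = pts[i - 1, 0] + 0.01
f = interp1d(pts[:, 0], pts[:, 1], kind='linear')
x_new = np.linspace(pts[0, 0], pts[-1, 0], num=155)
y_new = f(x_new)                           # 155 evenly spaced samples along the contour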

Exporting ONNX

import argparse
import json
import os
import platform
import subprocess
import sys
import time
import warnings
from pathlib import Path

import pandas as pd
import torch
import yaml
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.experimental import attempt_load
from models.yolo import Detect
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
                           check_yaml, colorstr, file_size, get_default_args, print_args, url2file)
from utils.torch_utils import select_device, smart_inference_mode


def export_formats():
    # YOLOv5 export formats
    x = [['PyTorch', '-', '.pt', True, True], ['TorchScript', 'torchscript', '.torchscript', True, True],
         ['ONNX', 'onnx', '.onnx', True, True], ['OpenVINO', 'openvino', '_openvino_model', True, False],
         ['TensorRT', 'engine', '.engine', False, True], ['CoreML', 'coreml', '.mlmodel', True, False],
         ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
         ['TensorFlow GraphDef', 'pb', '.pb', True, True], ['TensorFlow Lite', 'tflite', '.tflite', True, False],
         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
         ['TensorFlow.js', 'tfjs', '_web_model', False, False], ]
    return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])


def try_export(inner_func):
    # YOLOv5 export decorator, i.e. @try_export
    inner_args = get_default_args(inner_func)

    def outer_func(*args, **kwargs):
        prefix = inner_args['prefix']
        try:
            with Profile() as dt:
                f, model = inner_func(*args, **kwargs)
            LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
            return f, model
        except Exception as e:
            LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
            return None, None

    return outer_func


@try_export
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
    # YOLOv5 ONNX export
    check_requirements(('onnx',))
    import onnx

    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = file.with_suffix('.onnx')

    torch.onnx.export(model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
                      im.cpu() if dynamic else im, f, verbose=False, opset_version=opset,
                      training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=not train, input_names=['images'], output_names=['output'],
                      dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                    'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                    } if dynamic else None)

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Metadata
    d = {'stride': int(max(model.stride)), 'names': model.names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)
    onnx.save(model_onnx, f)

    # Simplify
    if simplify:
        try:
            cuda = torch.cuda.is_available()
            check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
            import onnxsim

            LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
            model_onnx, check = onnxsim.simplify(model_onnx)
            assert check, 'assert check failed'
            onnx.save(model_onnx, f)
        except Exception as e:
            LOGGER.info(f'{prefix} simplifier failure: {e}')
    return f, model_onnx


@try_export
def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
    # YOLOv5 OpenVINO export
    check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
    import openvino.inference_engine as ie

    LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
    f = str(file).replace('.pt', f'_openvino_model{os.sep}')

    cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
    subprocess.check_output(cmd.split())  # export
    with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
        yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g)  # add metadata.yaml
    return f, None


@try_export
def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
    # YOLOv5 CoreML export
    check_requirements(('coremltools',))
    import coremltools as ct

    LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
    f = file.with_suffix('.mlmodel')

    ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
    ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
    bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
    if bits < 32:
        if platform.system() == 'Darwin':  # quantization only supported on macOS
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
                ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
        else:
            print(f'{prefix} quantization only supported on macOS, skipping...')
    ct_model.save(f)
    return f, ct_model


@try_export
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
    assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
    try:
        import tensorrt as trt
    except Exception:
        if platform.system() == 'Linux':
            check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
        import tensorrt as trt

    if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
        grid = model.model[-1].anchor_grid
        model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
        export_onnx(model, im, file, 12, False, dynamic, simplify)  # opset 12
        model.model[-1].anchor_grid = grid
    else:  # TensorRT >= 8
        check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
        export_onnx(model, im, file, 13, False, dynamic, simplify)  # opset 13
    onnx = file.with_suffix('.onnx')

    LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
    assert onnx.exists(), f'failed to export ONNX file: {onnx}'
    f = file.with_suffix('.engine')  # TensorRT engine file
    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE

    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    config.max_workspace_size = workspace * 1 << 30
    # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice

    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    LOGGER.info(f'{prefix} Network Description:')
    for inp in inputs:
        LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

    if dynamic:
        if im.shape[0] <= 1:
            LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
        profile = builder.create_optimization_profile()
        for inp in inputs:
            profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
        config.add_optimization_profile(profile)

    LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        t.write(engine.serialize())
    return f, None


@try_export
def export_saved_model(model, im, file, dynamic, tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100,
                       iou_thres=0.45, conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')):
    # YOLOv5 TensorFlow SavedModel export
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    from models.tf import TFModel

    LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
    f = str(file).replace('.pt', '_saved_model')
    batch_size, ch, *imgsz = list(im.shape)  # BCHW

    tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
    im = tf.zeros((batch_size, *imgsz, ch))  # BHWC order for TensorFlow
    _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
    inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
    outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
    keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
    keras_model.trainable = False
    keras_model.summary()
    if keras:
        keras_model.save(f, save_format='tf')
    else:
        spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
        m = tf.function(lambda x: keras_model(x))  # full model
        m = m.get_concrete_function(spec)
        frozen_func = convert_variables_to_constants_v2(m)
        tfm = tf.Module()
        tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec])
        tfm.__call__(im)
        tf.saved_model.save(tfm, f,
                            options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
                                tf.__version__, '2.6') else tf.saved_model.SaveOptions())
    return f, keras_model


@try_export
def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
    # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
    f = file.with_suffix('.pb')

    m = tf.function(lambda x: keras_model(x))  # full model
    m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
    frozen_func = convert_variables_to_constants_v2(m)
    frozen_func.graph.as_graph_def()
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
    return f, None


@try_export
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
    # YOLOv5 TensorFlow Lite export
    import tensorflow as tf

    LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
    batch_size, ch, *imgsz = list(im.shape)  # BCHW
    f = str(file).replace('.pt', '-fp16.tflite')

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
    converter.target_spec.supported_types = [tf.float16]
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if int8:
        from models.tf import representative_dataset_gen
        dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
        converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.target_spec.supported_types = []
        converter.inference_input_type = tf.uint8  # or tf.int8
        converter.inference_output_type = tf.uint8  # or tf.int8
        converter.experimental_new_quantizer = True
        f = str(file).replace('.pt', '-int8.tflite')
    if nms or agnostic_nms:
        converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)

    tflite_model = converter.convert()
    open(f, "wb").write(tflite_model)
    return f, None


@try_export
def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
    # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
    cmd = 'edgetpu_compiler --version'
    help_url = 'https://coral.ai/docs/edgetpu/compiler/'
    assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
    if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
        LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
        sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system
        for c in (
                'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
                'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
                'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
            subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
    ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

    LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
    f = str(file).replace('.pt', '-int8_edgetpu.tflite')  # Edge TPU model
    f_tfl = str(file).replace('.pt', '-int8.tflite')  # TFLite model

    cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
    subprocess.run(cmd.split(), check=True)
    return f, None


@try_export
def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
    # YOLOv5 TensorFlow.js export
    check_requirements(('tensorflowjs',))
    import re

    import tensorflowjs as tfjs

    LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
    f = str(file).replace('.pt', '_web_model')  # js dir
    f_pb = file.with_suffix('.pb')  # *.pb path
    f_json = f'{f}/model.json'  # *.json path

    cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
          f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
    subprocess.run(cmd.split())

    json = Path(f_json).read_text()
    with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
        subst = re.sub(r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
                       r'"Identity.?.?": {"name": "Identity.?.?"}, '
                       r'"Identity.?.?": {"name": "Identity.?.?"}, '
                       r'"Identity.?.?": {"name": "Identity.?.?"}}}',
                       r'{"outputs": {"Identity": {"name": "Identity"}, '
                       r'"Identity_1": {"name": "Identity_1"}, '
                       r'"Identity_2": {"name": "Identity_2"}, '
                       r'"Identity_3": {"name": "Identity_3"}}}', json)
        j.write(subst)
    return f, None


@smart_inference_mode()
def run(data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
        weights=ROOT / 'yolov5s.pt',  # weights path
        imgsz=(640, 640),  # image (height, width)
        batch_size=1,  # batch size
        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        include=('torchscript', 'onnx'),  # include formats
        half=False,  # FP16 half-precision export
        inplace=False,  # set YOLOv5 Detect() inplace=True
        train=False,  # model.train() mode
        keras=False,  # use Keras
        optimize=False,  # TorchScript: optimize for mobile
        int8=False,  # CoreML/TF INT8 quantization
        dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
        simplify=False,  # ONNX: simplify model
        opset=12,  # ONNX: opset version
        verbose=False,  # TensorRT: verbose log
        workspace=4,  # TensorRT: workspace size (GB)
        nms=False,  # TF: add NMS to model
        agnostic_nms=False,  # TF: add agnostic NMS to model
        topk_per_class=100,  # TF.js NMS: topk per class to keep
        topk_all=100,  # TF.js NMS: topk for all classes to keep
        iou_thres=0.45,  # TF.js NMS: IoU threshold
        conf_thres=0.25,  # TF.js NMS: confidence threshold
        ):
    t = time.time()
    include = [x.lower() for x in include]  # to lowercase
    fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
    flags = [x in include for x in fmts]
    assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleans
    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights

    # Load PyTorch model
    device = select_device(device)
    if half:
        assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
        assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
    model = attempt_load(weights, device=device, inplace=True, fuse=True)  # load FP32 model

    # Checks
    imgsz *= 2 if len(imgsz) == 1 else 1  # expand
    if optimize:
        assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
    im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection

    # Update model
    model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
    for k, m in model.named_modules():
        if isinstance(m, Detect):
            m.inplace = inplace
            m.dynamic = dynamic
            m.export = True

    for _ in range(2):
        y = model(im)  # dry runs
    if half and not coreml:
        im, model = im.half(), model.half()  # to FP16
    shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

    # Exports
    f = [''] * 10  # exported filenames
    warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
    if engine:  # TensorRT required before ONNX
        f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
    if onnx or xml:  # OpenVINO requires ONNX
        f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
    if xml:  # OpenVINO
        f[3], _ = export_openvino(model, file, half)
    if coreml:
        f[4], _ = export_coreml(model, im, file, int8, half)

    # TensorFlow Exports
    if any((saved_model, pb, tflite, edgetpu, tfjs)):
        if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
            check_requirements(('flatbuffers==1.12',))  # required before `import tensorflow`
        assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
        f[5], model = export_saved_model(model.cpu(), im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
                                         agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class,
                                         topk_all=topk_all, iou_thres=iou_thres, conf_thres=conf_thres, keras=keras)
        if pb or tfjs:  # pb prerequisite to tfjs
            f[6], _ = export_pb(model, file)
        if tflite or edgetpu:
            f[7], _ = export_tflite(model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
        if edgetpu:
            f[8], _ = export_edgetpu(file)
        if tfjs:
            f[9], _ = export_tfjs(file)

    # Finish
    f = [str(x) for x in f if x]  # filter out '' and None
    if any(f):
        h = '--half' if half else ''  # --half FP16 inference arg
        LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
                    f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
                    f"\nDetect: python detect.py --weights {f[-1]} {h}"
                    f"\nValidate: python val.py --weights {f[-1]} {h}"
                    f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"
                    f"\nVisualize: https://netron.app")
    return f  # return list of exported files/dirs


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='../data/data_y.yaml', help='dataset.yaml path')
    parser.add_argument('--weights', nargs='+', type=str, default='./runs/train-seg/exp/weights/best.pt',
                        help='model.pt path(s)')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
    parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
    parser.add_argument('--train', action='store_true', help='model.train() mode')
    parser.add_argument('--keras', action='store_true', help='TF: use Keras')
    parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
    parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
    parser.add_argument('--simplify', action='store_true', default=True, help='ONNX: simplify model')
    parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
    parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
    parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
    parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
    parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
    parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
    parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
    parser.add_argument('--include', nargs='+', default=['onnx'],
                        help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
    opt = parser.parse_args()
    print_args(vars(opt))
    return opt


if __name__ == "__main__":
    opt = parse_opt()
    for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
        run(**vars(opt))
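With the defaults above (--weights ./runs/train-seg/exp/weights/best.pt, --include onnx, --simplify enabled), a typical invocation should look like the following and write best.onnx next to the weights; the weights path is a placeholder for your own run:

python export.py --weights ./runs/train-seg/exp/weights/best.pt --include onnx --imgsz 640 640 --device cpu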

ONNX structure:

Converting to ncnn:

onnx2ncnn yolov7_mask.onnx yolov7_mask.param yolov7_mask.bin

Errors reported:

Unsupported slice axes !
ScatterND not supported yet!
Unsupported slice axes !
ScatterND not supported yet!
Unsupported slice axes !
ScatterND not supported yet!

yolov7 mask Python training, TensorRT inference framework

The whole pipeline (Python training plus TensorRT inference) has been run end to end; if you need it, contact the author (AI视觉网奇) by private message on CSDN.

Comparing original images with results in Python:

import json
import os, cv2
import numpy as np


def visualization_bbox1(json_path, img_path):  # path of the merged COCO json and directory of the images to draw
    with open(json_path, encoding='utf-8') as annos:
        annotation_json = json.load(annos)
    print('num_key is:', len(annotation_json), 'json key is:', annotation_json.keys())  # keys of the json file
    print('json num_images is:', len(annotation_json['images']))  # number of images listed in the json file
    for img_i in range(len(annotation_json['images'])):
        image_name = annotation_json['images'][img_i]['file_name']  # image name
        id = annotation_json['images'][img_i]['id']  # image id
        image_path = os.path.join(img_path, str(image_name).zfill(5))  # build the image path
        image = cv2.imread(image_path, 1)  # read the image in its original format
        num_bbox = 0  # number of bboxes in this image
        coco_train = dict()
        coco_train['flags'] = {}
        coco_train['imagePath'] = os.path.basename(image_path)
        coco_train['shapes'] = []
        coco_train['imageData'] = None
        coco_train['imageHeight'] = image.shape[0]
        coco_train['imageWidth'] = image.shape[1]
        for i in range(len(annotation_json['annotations'][::])):
            if annotation_json['annotations'][i - 1]['image_id'] == id:
                num_bbox = num_bbox + 1
                box_dict = {}
                box_dict["label"] = "quexian"
                box_dict["shape_type"] = "polygon"
                box_dict["points"] = []
                x, y, w, h = annotation_json['annotations'][i - 1]['bbox']  # read the bounding box
                image = cv2.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), (0, 255, 255), 2)
                points = annotation_json['annotations'][i - 1]['segmentation']  # keypoints
                data_len = len(points[0]) // 2
                for index in range(data_len):
                    cv2.circle(image, (int(points[0][index * 2]), int(points[0][index * 2 + 1])), 3, (0, 0, 213), -1)  # x, y, r, color
                    box_dict["points"].append((int(points[0][index * 2]), int(points[0][index * 2 + 1])))
                coco_train['shapes'].append(box_dict)
        # cv2.imwrite(f"gt/{img_i}.jpg", image)
        # cv2.resizeWindow("image_name", 2500, 1250)
        if is_show:
            print(os.path.basename(image_path))
            # cv2.namedWindow("image_name", 0)  # create a resizable window
            if image.shape[1] > 1000:
                x_scale = 1000 / image.shape[1]
                image = cv2.resize(image, None, fx=x_scale, fy=x_scale, interpolation=cv2.INTER_AREA)
            cv2.imshow("image_name", image)
            img_result = cv2.imread(result_dir + os.path.basename(image_path))
            if img_result.shape[1] > 1000:
                x_scale = 1000 / img_result.shape[1]
                img_result = cv2.resize(img_result, None, fx=x_scale, fy=x_scale, interpolation=cv2.INTER_AREA)
            cv2.imshow("img_result", img_result)
            cv2.waitKey(0)
        else:
            train_file = image_path[:-4] + ".json"
            with open(train_file, 'w') as write_f:
                write_f.write(json.dumps(coco_train, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    is_show = True
    result_dir = r'D:\work\lbg\fenge\yolov7_mask-main\segment\runs\predict-seg\exp10/'
    train_json = r'D:\work\lbg\fenge\data/mark.json'
    train_path = r'D:\work\lbg\fenge\data\trainImage'
    os.makedirs(train_path, exist_ok=True)
    visualization_bbox1(train_json, train_path)

 
