
Running YOLOv7 ONNX Inference on a USB Webcam

I've recently been working on a project that needs YOLOv7, but the official repo only ships ONNX inference code for single images, and most of what's posted online simply copies it, bugs and all. This post fixes those issues.

If you're interested, the official ONNX image-inference example in the YOLOv7 repository is worth a look.

First, export the weights to ONNX following the official instructions:

python export.py --weights weights/yolov7.pt --grid --end2end --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640 --max-wh 640

My versions: onnx==1.9.0, onnx-simplifier==0.3.6.
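Before wiring the export into the detection script, it's worth sanity-checking the file. A minimal sketch (the path assumes the export command above):

import onnx
import onnxruntime as ort

# Structural check of the exported graph
model = onnx.load('weights/yolov7.onnx')
onnx.checker.check_model(model)

# Inspect what the --end2end export actually takes and returns
session = ort.InferenceSession('weights/yolov7.onnx', providers=['CPUExecutionProvider'])
for i in session.get_inputs():
    print('input:', i.name, i.shape, i.type)    # expect a fixed (1, 3, 640, 640) float32 input
for o in session.get_outputs():
    print('output:', o.name, o.shape, o.type)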

OK, back to the main topic. Here is the code:

import argparse
import time
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random
import numpy
import onnxruntime as ort

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized, TracedModel

names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
         'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
         'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
         'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
         'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
         'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
         'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
         'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
         'hair drier', 'toothbrush']
class ONNX_engine():
    def __init__(self, weights, size, cuda) -> None:
        self.img_new_shape = (size, size)
        self.weights = weights
        self.device = cuda
        self.init_engine()
        self.names = names  # the global COCO class list defined above
        self.colors = {name: [random.randint(0, 255) for _ in range(3)] for name in self.names}

    def init_engine(self):
        # Use the CUDA provider when a GPU device was selected, otherwise fall back to CPU
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.device.type != 'cpu' \
            else ['CPUExecutionProvider']
        self.session = ort.InferenceSession(self.weights[0], providers=providers)
        self.outname = [o.name for o in self.session.get_outputs()]
        self.inname = [i.name for i in self.session.get_inputs()]

    def predict(self, im):
        # im: float32 ndarray of shape (1, 3, H, W); returns the (num_det, 7) end2end output
        outputs = self.session.run(self.outname, {self.inname[0]: im})[0]
        return outputs
def detect(save_img=False):
    source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://', 'https://'))

    # Directories
    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Initialize
    set_logging()
    device = select_device(opt.device)
    half = False  # the exported ONNX model is FP32, so keep the input float32 even on CUDA

    # Load model
    model = ONNX_engine(weights, imgsz, device)
    stride = 32

    vid_path, vid_writer = None, None
    if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, onnx=True)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride)

    # Get colors (class names are the global list defined above)
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

    # Run inference once to warm up (ONNX Runtime expects a float32 numpy array, not a torch tensor)
    if device.type != 'cpu':
        model.predict(numpy.zeros((1, 3, imgsz, imgsz), dtype=numpy.float32))
    old_img_w = old_img_h = imgsz
    old_img_b = 1

    t0 = time.time()
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Warmup again if the input shape changed
        if device.type != 'cpu' and (
                old_img_b != img.shape[0] or old_img_h != img.shape[2] or old_img_w != img.shape[3]):
            old_img_b = img.shape[0]
            old_img_h = img.shape[2]
            old_img_w = img.shape[3]
            for i in range(3):
                model.predict(img.cpu().numpy())

        # Inference. The --end2end export already contains NMS, so the session returns
        # final detections of shape (num_det, 7): batch_id, x1, y1, x2, y2, cls, conf
        t1 = time_synchronized()
        with torch.no_grad():
            pred = torch.from_numpy(model.predict(img.cpu().numpy())).unsqueeze(0)
        t2 = time_synchronized()

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
            else:
                p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # img.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, 1:5] = scale_coords(img.shape[2:], det[:, 1:5], im0.shape).round()

                # Print results
                for c in det[:, -2].unique():
                    n = (det[:, -2] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for batch_id, *xyxy, cls, conf in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
                    if save_img or view_img:  # Add bbox to image
                        label = f'{names[int(cls)]} {conf:.2f}'
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=1)

            # Print time (inference only; NMS already ran inside the ONNX graph)
            print(f'{s}Done. ({(1E3 * (t2 - t1)):.1f}ms) Inference')

            # Stream results
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                    print(f" The image with the result is saved in: {save_path}")
                else:  # 'video' or 'stream'
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                            save_path += '.mp4'
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer.write(im0)

    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        print(f"Results saved to {save_dir}{s}")

    print(f'Done. ({time.time() - t0:.3f}s)')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default=['yolov7.onnx'], help='model.onnx path(s)')
    parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_false', help='display results (default True so the webcam window shows)')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    opt = parser.parse_args()
    print(opt)
    # check_requirements(exclude=('pycocotools', 'thop'))

    with torch.no_grad():
        if opt.update:  # update all models (to fix SourceChangeWarning)
            for opt.weights in ['yolov7.pt']:
                detect()
                strip_optimizer(opt.weights)
        else:
            detect()
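For reference, the --end2end export bakes NMS into the ONNX graph, which is why the script above never calls non_max_suppression: the session returns a single (num_det, 7) array whose rows the loop unpacks as (batch_id, x1, y1, x2, y2, cls, conf), with boxes in letterboxed 640x640 coordinates that scale_coords maps back to the original frame. A minimal standalone sketch of that contract (the image and model paths are illustrative):

import cv2
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('weights/yolov7.onnx', providers=['CPUExecutionProvider'])

img = cv2.imread('inference/images/horses.jpg')              # BGR, any size
blob = cv2.resize(img, (640, 640))[:, :, ::-1]               # stretch to the fixed square input, BGR -> RGB
blob = blob.transpose(2, 0, 1)[None].astype(np.float32) / 255.0

dets = session.run(None, {session.get_inputs()[0].name: blob})[0]
for batch_id, x1, y1, x2, y2, cls_id, score in dets:
    # boxes are still in 640x640 space here; the script above rescales them with scale_coords
    print(f'class {int(cls_id)}, score {score:.2f}, box ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})')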

This script is adapted from detect.py and the official ONNX inference example; put it in the same directory as detect.py. One caveat: dataset.py must not use the minimum-rectangle option, because the ONNX graph expects a fixed square input, so the letterbox function has to produce an image with equal width and height. I changed several places, so the whole file is posted below to avoid missing any detail; replace the original utils/datasets.py with it. (Note that the new defaults make letterbox always stretch to a square, so keep a backup if you still train with this repo; a quick sanity check of the modified function follows the file.)

# Dataset utils and dataloaders
import glob
import logging
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

import pickle
from copy import deepcopy
# from pycocotools import mask as maskUtils
from torchvision.utils import save_image
from torchvision.ops import roi_pool, roi_align, ps_roi_pool, ps_roi_align

from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, \
    resample_segments, clean_str
from utils.torch_utils import torch_distributed_zero_first

# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
logger = logging.getLogger(__name__)

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except:
        pass
    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
class LoadImages:  # for inference
    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            # print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files


class LoadWebcam:  # for inference
    def __init__(self, pipe='0', img_size=640, stride=32):
        self.img_size = img_size
        self.stride = stride

        if pipe.isnumeric():
            pipe = eval(pipe)  # local camera
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, f'Camera Error {self.pipe}'
        img_path = 'webcam.jpg'
        print(f'webcam {self.count}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0
class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640, stride=32, onnx=True):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print(f'{i + 1}/{n}: {s}... ', end='')
            url = eval(s) if s.isnumeric() else s
            if 'youtube.com/' in str(url) or 'youtu.be/' in str(url):  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl'))
                import pafy
                url = pafy.new(url).getbest(preftype="mp4").url
            cap = cv2.VideoCapture(url)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps = (cap.get(cv2.CAP_PROP_FPS) % 100) + 1
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(f' success ({w}x{h} at {self.fps:.2f} FPS).')
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride, onnx=onnx)[0].shape for x in self.imgs], 0)  # shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                success, im = cap.retrieve()
                self.imgs[index] = im if success else self.imgs[index] * 0
                n = 0
            time.sleep(1 / self.fps)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]


class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        # self.albumentations = Albumentations() if augment else None

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            # if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
            #     cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

        # Read cache
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            if cache_images == 'disk':
                self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
                self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
                self.im_cache_dir.mkdir(parents=True, exist_ok=True)
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                if cache_images == 'disk':
                    if not self.img_npy[i].exists():
                        np.save(self.img_npy[i].as_posix(), x[0])
                    gb += self.img_npy[i].stat().st_size
                else:
                    self.imgs[i], self.img_hw0[i], self.img_hw[i] = x
                    gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, duplicate
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for i, (im_file, lb_file) in enumerate(pbar):
            try:
                # verify images
                im = Image.open(im_file)
                im.verify()  # PIL verify
                shape = exif_size(im)  # image size
                segments = []  # instance segments
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert im.format.lower() in img_formats, f'invalid image format {im.format}'

                # verify labels
                if os.path.isfile(lb_file):
                    nf += 1  # label found
                    with open(lb_file, 'r') as f:
                        l = [x.split() for x in f.read().strip().splitlines()]
                        if any([len(x) > 8 for x in l]):  # is segment
                            classes = np.array([x[0] for x in l], dtype=np.float32)
                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                        l = np.array(l, dtype=np.float32)
                    if len(l):
                        assert l.shape[1] == 5, 'labels require 5 columns each'
                        assert (l >= 0).all(), 'negative labels'
                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
                    else:
                        ne += 1  # label empty
                        l = np.zeros((0, 5), dtype=np.float32)
                else:
                    nm += 1  # label missing
                    l = np.zeros((0, 5), dtype=np.float32)
                x[im_file] = [l, shape, segments]
            except Exception as e:
                nc += 1
                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
        pbar.close()

        if nf == 0:
            print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, i + 1
        x['version'] = 0.1  # cache version
        torch.save(x, path)  # save for next time
        logging.info(f'{prefix}New cache created: {path}')
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            if random.random() < 0.8:
                img, labels = load_mosaic(self, index)
            else:
                img, labels = load_mosaic9(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                if random.random() < 0.8:
                    img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
                else:
                    img2, labels2 = load_mosaic9(self, random.randint(0, len(self.labels) - 1))
                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            # Augment imagespace
            if not mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # img, labels = self.albumentations(img, labels)

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

            if random.random() < hyp['paste_in']:
                sample_labels, sample_images, sample_masks = [], [], []
                while len(sample_labels) < 30:
                    sample_labels_, sample_images_, sample_masks_ = load_samples(self,
                                                                                 random.randint(0, len(self.labels) - 1))
                    sample_labels += sample_labels_
                    sample_images += sample_images_
                    sample_masks += sample_masks_
                    if len(sample_labels) == 0:
                        break
                labels = pastein(img, labels, sample_labels, sample_images, sample_masks)

        nL = len(labels)  # number of labels
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed


def hist_equalize(img, clahe=True, bgr=False):
    # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB


def load_mosaic(self, index):
    # loads images in a 4-mosaic

    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    # img4, labels4, segments4 = remove_background(img4, labels4, segments4)
    # sample_segments(img4, labels4, segments4, probability=self.hyp['copy_paste'])
    img4, labels4, segments4 = copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste'])
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4


def load_mosaic9(self, index):
    # loads images in a 9-mosaic

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    # img9, labels9, segments9 = remove_background(img9, labels9, segments9)
    img9, labels9, segments9 = copy_paste(img9, labels9, segments9, probability=self.hyp['copy_paste'])
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9


def load_samples(self, index):
    # loads images in a 4-mosaic

    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    # img4, labels4, segments4 = remove_background(img4, labels4, segments4)
    sample_labels, sample_images, sample_masks = sample_segments(img4, labels4, segments4, probability=0.5)

    return sample_labels, sample_images, sample_masks


def copy_paste(img, labels, segments, probability=0.5):
    # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
    n = len(segments)
    if probability and n:
        h, w, c = img.shape  # height, width, channels
        im_new = np.zeros(img.shape, np.uint8)
        for j in random.sample(range(n), k=round(probability * n)):
            l, s = labels[j], segments[j]
            box = w - l[3], l[2], w - l[1], l[4]
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
                labels = np.concatenate((labels, [[l[0], *box]]), 0)
                segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
                cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)

        result = cv2.bitwise_and(src1=img, src2=im_new)
        result = cv2.flip(result, 1)  # augment segments (flip left-right)
        i = result > 0  # pixels to replace
        # i[:, :] = result.max(2).reshape(h, w, 1)  # act over ch
        img[i] = result[i]  # cv2.imwrite('debug.jpg', img)  # debug

    return img, labels, segments


def remove_background(img, labels, segments):
    # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
    n = len(segments)
    h, w, c = img.shape  # height, width, channels
    im_new = np.zeros(img.shape, np.uint8)
    img_new = np.ones(img.shape, np.uint8) * 114
    for j in range(n):
        cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)

        result = cv2.bitwise_and(src1=img, src2=im_new)

        i = result > 0  # pixels to replace
        img_new[i] = result[i]  # cv2.imwrite('debug.jpg', img)  # debug

    return img_new, labels, segments


def sample_segments(img, labels, segments, probability=0.5):
    # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
    n = len(segments)
    sample_labels = []
    sample_images = []
    sample_masks = []
    if probability and n:
        h, w, c = img.shape  # height, width, channels
        for j in random.sample(range(n), k=round(probability * n)):
            l, s = labels[j], segments[j]
            box = l[1].astype(int).clip(0, w - 1), l[2].astype(int).clip(0, h - 1), \
                  l[3].astype(int).clip(0, w - 1), l[4].astype(int).clip(0, h - 1)

            if (box[2] <= box[0]) or (box[3] <= box[1]):
                continue

            sample_labels.append(l[0])

            mask = np.zeros(img.shape, np.uint8)
            cv2.drawContours(mask, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
            sample_masks.append(mask[box[1]:box[3], box[0]:box[2], :])

            result = cv2.bitwise_and(src1=img, src2=mask)
            i = result > 0  # pixels to replace
            mask[i] = result[i]  # cv2.imwrite('debug.jpg', img)  # debug
            sample_images.append(mask[box[1]:box[3], box[0]:box[2], :])

    return sample_labels, sample_images, sample_masks


def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=True, scaleup=True, onnx=False,
              stride=32):
    # Resize and pad image while meeting stride-multiple constraints.
    # Modified for ONNX inference: the exported graph requires a fixed square input,
    # so onnx=True forces plain stretching (no minimum rectangle, no padding).
    if onnx:
        auto, scaleFill = False, True
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    if scaleFill:  # stretch to the full target shape (overrides auto), so dw = dh = 0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1.1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates


def bbox_ioa(box1, box2):
    # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
    box2 = box2.transpose()

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

    # Intersection area
    inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                 (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

    # box2 area
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

    # Intersection over box2 area
    return inter_area / box2_area


def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def pastein(image, labels, sample_labels, sample_images, sample_masks):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    # create random masks
    scales = [0.75] * 2 + [0.5] * 4 + [0.25] * 4 + [0.125] * 4 + [0.0625] * 6  # image size fraction
    for s in scales:
        if random.random() < 0.2:
            continue
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
        if len(labels):
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
        else:
            ioa = np.zeros(1)

        if (ioa < 0.30).all() and len(sample_labels) and (xmax > xmin + 20) and (
                ymax > ymin + 20):  # allow 30% obscuration of existing labels
            sel_ind = random.randint(0, len(sample_labels) - 1)
            hs, ws, cs = sample_images[sel_ind].shape
            r_scale = min((ymax - ymin) / hs, (xmax - xmin) / ws)
            r_w = int(ws * r_scale)
            r_h = int(hs * r_scale)

            if (r_w > 10) and (r_h > 10):
                r_mask = cv2.resize(sample_masks[sel_ind], (r_w, r_h))
                r_image = cv2.resize(sample_images[sel_ind], (r_w, r_h))
                temp_crop = image[ymin:ymin + r_h, xmin:xmin + r_w]
                m_ind = r_mask > 0
                if m_ind.astype(np.int32).sum() > 60:
                    temp_crop[m_ind] = r_image[m_ind]
                    box = np.array([xmin, ymin, xmin + r_w, ymin + r_h], dtype=np.float32)
                    if len(labels):
                        labels = np.concatenate((labels, [[sample_labels[sel_ind], *box]]), 0)
                    else:
                        labels = np.array([[sample_labels[sel_ind], *box]])

                    image[ymin:ymin + r_h, xmin:xmin + r_w] = temp_crop

    return labels


class Albumentations:
    # YOLOv5 Albumentations class (optional, only used if package is installed)
    def __init__(self):
        self.transform = None
        import albumentations as A

        self.transform = A.Compose([
            A.CLAHE(p=0.01),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.01),
            A.RandomGamma(gamma_limit=[80, 120], p=0.01),
            A.Blur(p=0.01),
            A.MedianBlur(p=0.01),
            A.ToGray(p=0.01),
            A.ImageCompression(quality_lower=75, p=0.01)],
            bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

        # logging.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms if x.p))

    def __call__(self, im, labels, p=1.0):
        if self.transform and random.random() < p:
            new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
            im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
        return im, labels


def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder


def flatten_recursive(path='../coco'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(path + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)


def extract_boxes(path='../coco/'):  # from utils.datasets import *; extract_boxes('../coco128')
    # Convert detection dataset into classification dataset, with one directory per class
    path = Path(path)  # images dir
    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in img_formats:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file, 'r') as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)  # np.int is deprecated, use int

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'


def autosplit(path='../coco', weights=(0.9, 0.1, 0.0), annotated_only=False):
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.datasets import *; autosplit('../coco')
    Arguments
        path:           Path to images directory
        weights:        Train, val, test weights (list)
        annotated_only: Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], [])  # image files only
    n = len(files)  # number of files
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    [(path / x).unlink() for x in txt if (path / x).exists()]  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path / txt[i], 'a') as f:
                f.write(str(img) + '\n')  # add image to txt file


def load_segmentations(self, index):
    key = '/work/handsomejw66/coco17/' + self.img_files[index]
    return self.segs[key]
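With the file replaced, a quick sanity check (a minimal sketch) confirms that the modified letterbox always returns a square image with no padding, which is what the fixed-size ONNX input requires:

import numpy as np
from utils.datasets import letterbox

frame = np.zeros((480, 640, 3), dtype=np.uint8)     # a fake 640x480 camera frame
out, ratio, (dw, dh) = letterbox(frame, 640, onnx=True)
print(out.shape, ratio, (dw, dh))                   # (640, 640, 3), stretched, dw = dh = 0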

To run detection on a camera, use the command below, where detect_onnx.py is the modified detect script from the top of this post. Here --source is my USB camera, whose id turned out to be 700; to run on a video or image instead, just replace 700 with its path.

python detect_onnx.py --weights weights/yolov7.onnx --conf 0.25 --img-size 640 --source 700

If you don't know your camera id, you can brute-force it with a small loop. I adapted this from someone else's code but have since lost the link; if you're the author, leave a comment and I'll credit you. For my setup the loop ended at id=700:

import cv2

cam_id = 0
while True:
    cap = cv2.VideoCapture(cam_id)
    ret, frame = cap.read()
    cap.release()  # release the device before trying the next id
    if not ret:
        cam_id += 1
        print(cam_id)
    else:
        print("final id =", cam_id)
        break
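As an aside, 700 is probably not a literal device index. In OpenCV, VideoCapture(index) treats the index as camera_index + api_preference, and cv2.CAP_DSHOW equals 700, so id 700 most likely means "device 0 through the DirectShow backend" on Windows. If that's the case on your machine, this is equivalent and more explicit:

import cv2

# Assuming id 700 decomposes as device 0 + cv2.CAP_DSHOW (which is 700)
print(cv2.CAP_DSHOW)                      # 700
cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)  # device 0 via the DirectShow backend
print(cap.isOpened())
cap.release()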
