当前位置:   article > 正文

YOLOV5代码datasets.py文件解读_yolov5-5.0 datasets.py

yolov5-5.0 datasets.py

YOLOV5源码的下载:

git clone https://github.com/ultralytics/yolov5.git

YOLOV5代码datasets.py文件解读:

  1. import glob
  2. import os
  3. import random
  4. import shutil
  5. import time
  6. from pathlib import Path
  7. from threading import Thread
  8. import cv2
  9. import math
  10. import numpy as np
  11. import torch
  12. from PIL import Image, ExifTags
  13. from torch.utils.data import Dataset
  14. from tqdm import tqdm
  15. from utils.general import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
  # Documentation page referenced in the dataset-loading error messages.
  help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'

  # Supported image file extensions (lower-case, including the leading dot).
  img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']

  # Supported video file extensions (lower-case, including the leading dot).
  vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']

  # Get orientation exif tag.
  # Exif (Exchangeable image file format) is the metadata format designed for digital-camera
  # photos; it records the attributes of a photo and its shooting data.
  # The loop leaves `orientation` bound to the numeric tag id named 'Orientation',
  # which exif_size() uses as a lookup key.
  for orientation in ExifTags.TAGS:
      if ExifTags.TAGS[orientation] == 'Orientation':
          break
  def get_hash(files):
      """Return a single fingerprint value for a list of file paths.

      The value is the sum of the byte sizes of every entry that is an existing
      regular file; entries that are not files contribute nothing.
      """
      total = 0
      for f in files:
          if os.path.isfile(f):
              total += os.path.getsize(f)
      return total
  def exif_size(img):
      """Return the PIL image size corrected for the EXIF orientation tag.

      Args:
          img: an opened PIL image (any object exposing ``size`` and ``_getexif()``).

      Returns:
          ``img.size`` as (width, height), with the two values swapped when the EXIF
          orientation value is 6 or 8 (the quarter-turn rotations). If the EXIF data is
          missing or cannot be read, ``img.size`` is returned unchanged.
      """
      s = img.size  # (width, height)
      try:
          exif = img._getexif()
          # `orientation` is the module-level EXIF tag id for 'Orientation'.
          rotation = dict(exif.items())[orientation] if exif else None
      except Exception:
          # Best effort only: unreadable or absent EXIF means "no rotation".
          # Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
          return s
      if rotation in (6, 8):  # 6: rotation 270, 8: rotation 90
          s = (s[1], s[0])
      return s
  def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                        rank=-1, world_size=1, workers=8):
      """Build a LoadImagesAndLabels dataset and an InfiniteDataLoader over it.

      Args:
          path: txt file listing image paths, or a folder that contains the images.
          imgsz: network input image size.
          batch_size: requested batch size (capped at the dataset length).
          stride: network down-sampling stride.
          opt: options passed in from train.py; only ``opt.single_cls`` is read here.
          hyp: training hyper-parameters; forwarded to the dataset for augmentation.
          augment: whether to apply data augmentation.
          cache: whether to cache images in memory up front to speed up training.
          pad: padding applied when computing rectangular-training shapes.
          rect: whether to use aspect-ratio-sorted rectangular training
              (mosaic augmentation is disabled when True).
          rank: process rank; -1 means no distributed sampler is used.
          world_size: number of processes sharing the CPU cores.
          workers: upper bound on the number of dataloader workers.

      Returns:
          Tuple ``(dataloader, dataset)``.
      """
      # Make sure only the first process in DDP (DistributedDataParallel) processes the dataset first,
      # and the following others can use the cache.
      with torch_distributed_zero_first(rank):
          dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                        augment=augment,  # augment images
                                        hyp=hyp,  # augmentation hyperparameters
                                        rect=rect,  # rectangular training
                                        cache_images=cache,
                                        single_cls=opt.single_cls,
                                        stride=int(stride),
                                        pad=pad,
                                        rank=rank)

      batch_size = min(batch_size, len(dataset))
      # os.cpu_count() returns None when the count cannot be determined; treat that as one core.
      cpus = os.cpu_count() or 1
      nw = min([cpus // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
      # In distributed mode, a DistributedSampler assigns sample indices to each rank's process.
      sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
      dataloader = InfiniteDataLoader(dataset,
                                      batch_size=batch_size,
                                      num_workers=nw,
                                      sampler=sampler,
                                      pin_memory=True,
                                      collate_fn=LoadImagesAndLabels.collate_fn)  # torch.utils.data.DataLoader()
      return dataloader, dataset
  # A DataLoader takes a chunk of time at the start of every epoch to start its worker processes.
  # This subclass creates its iterator once, at construction, and reuses it afterwards.
  class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
      """Dataloader that reuses workers.

      Uses the same syntax as the vanilla DataLoader.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Rebind batch_sampler to a never-ending wrapper. object.__setattr__ bypasses
          # DataLoader's own __setattr__ (presumably it rejects rebinding this attribute
          # after construction -- confirm against the torch version in use).
          repeating = _RepeatSampler(self.batch_sampler)
          object.__setattr__(self, 'batch_sampler', repeating)
          self.iterator = super().__iter__()

      def __len__(self):
          # Length of the wrapped (original) batch sampler, i.e. batches per pass.
          return len(self.batch_sampler.sampler)

      def __iter__(self):
          # Yield exactly len(self) batches from the single long-lived iterator.
          remaining = len(self)
          while remaining > 0:
              yield next(self.iterator)
              remaining -= 1
  class _RepeatSampler(object):
      """Sampler wrapper that repeats forever.

      Args:
          sampler (Sampler): the sampler to replay; every pass starts a fresh iteration over it.
      """

      def __init__(self, sampler):
          self.sampler = sampler

      def __iter__(self):
          # Restart the wrapped sampler each time it is exhausted; never raises StopIteration.
          while True:
              for item in self.sampler:
                  yield item
  class LoadImages:  # for inference
      """Iterator over image and video files (used by detect.py).

      Each step returns ``(path, img, img0, cap)``: the file path, the letterboxed
      channel-first RGB array, the original BGR frame, and the current
      cv2.VideoCapture object (None if no video was found).
      """

      def __init__(self, path, img_size=640):
          # os-agnostic, absolute form of the requested path
          p = os.path.abspath(str(Path(path)))

          if '*' in p:  # glob pattern
              files = sorted(glob.glob(p, recursive=True))
          elif os.path.isdir(p):  # directory: every "name.ext" entry directly inside it
              files = sorted(glob.glob(os.path.join(p, '*.*')))
          elif os.path.isfile(p):  # single file
              files = [p]
          else:
              raise Exception('ERROR: %s does not exist' % p)

          # Split the candidates by (lower-cased) extension.
          images, videos = [], []
          for f in files:
              ext = os.path.splitext(f)[-1].lower()
              if ext in img_formats:
                  images.append(f)
              if ext in vid_formats:
                  videos.append(f)
          ni, nv = len(images), len(videos)

          self.img_size = img_size  # target size handed to letterbox()
          self.files = images + videos  # images first, then videos
          self.nf = ni + nv  # number of files
          self.video_flag = [False] * ni + [True] * nv  # True where files[i] is a video
          self.mode = 'images'  # switched to 'video' once a video frame is read
          if any(videos):
              self.new_video(videos[0])  # open the first video right away
          else:
              self.cap = None
          assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
                              (p, img_formats, vid_formats)

      def __iter__(self):
          self.count = 0
          return self

      def __next__(self):
          if self.count == self.nf:  # every file has been consumed
              raise StopIteration
          path = self.files[self.count]
          is_video = self.video_flag[self.count]

          if is_video:
              # Read video
              self.mode = 'video'
              ret_val, img0 = self.cap.read()
              if not ret_val:
                  # Current video is exhausted: move on to the next file.
                  self.count += 1
                  self.cap.release()
                  if self.count == self.nf:  # that was the last video
                      raise StopIteration
                  else:
                      path = self.files[self.count]
                      self.new_video(path)
                      ret_val, img0 = self.cap.read()

              self.frame += 1  # frames read so far from the current video
              print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')
          else:
              # Read image
              self.count += 1
              img0 = cv2.imread(path)  # BGR
              assert img0 is not None, 'Image Not Found ' + path
              print('image %g/%g %s: ' % (self.count, self.nf, path), end='')

          # Padded resize
          resized = letterbox(img0, new_shape=self.img_size)[0]

          # Convert: reverse the channel axis (BGR to RGB), move channels first (HWC to CHW),
          # then make the memory contiguous.
          img = np.ascontiguousarray(resized[:, :, ::-1].transpose(2, 0, 1))
          return path, img, img0, self.cap

      def new_video(self, path):
          """Open `path` as the current video and reset the frame counter."""
          self.frame = 0
          self.cap = cv2.VideoCapture(path)
          self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))  # total frames reported by OpenCV

      def __len__(self):
          return self.nf  # number of files
  class LoadWebcam:  # for inference
      """Iterator over frames from one camera or stream (marked as unused in the original listing).

      Each step returns ``('webcam.jpg', img, img0, None)`` where ``img`` is the letterboxed
      channel-first RGB array and ``img0`` the original BGR frame. Pressing 'q' in an
      OpenCV window ends the iteration.
      """

      def __init__(self, pipe=0, img_size=640):
          self.img_size = img_size

          if pipe == '0':
              pipe = 0  # local camera
          # Examples of other pipes:
          # pipe = 'rtsp://192.168.1.64/1'  # IP camera
          # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
          # pipe = 'rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa'  # IP traffic camera
          # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

          # https://answers.opencv.org/question/215996/changing-gstreamer-pipeline-to-opencv-in-pythonsolved/
          # pipe = '"rtspsrc location="rtsp://username:password@192.168.1.64/1" latency=10 ! appsink'  # GStreamer

          # https://answers.opencv.org/question/200787/video-acceleration-gstremer-pipeline-in-videocapture/
          # https://stackoverflow.com/questions/54095699/install-gstreamer-support-for-opencv-python-package  # install help
          # pipe = "rtspsrc location=rtsp://root:root@192.168.0.91:554/axis-media/media.amp?videocodec=h264&resolution=3840x2160 protocols=GST_RTSP_LOWER_TRANS_TCP ! rtph264depay ! queue ! vaapih264dec ! videoconvert ! appsink"  # GStreamer

          self.pipe = pipe
          self.cap = cv2.VideoCapture(pipe)  # video capture object
          self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

      def __iter__(self):
          self.count = -1
          return self

      def __next__(self):
          self.count += 1
          if cv2.waitKey(1) == ord('q'):  # q to quit
              self.cap.release()
              cv2.destroyAllWindows()
              raise StopIteration

          # Read frame
          if self.pipe == 0:  # local camera
              ret_val, img0 = self.cap.read()  # read() = grab() + retrieve(): fetch and decode the next frame
          else:  # IP camera
              n = 0
              while True:
                  n += 1
                  self.cap.grab()  # fetch the next frame without decoding it
                  if n % 30 == 0:  # skip frames: decode only every 30th grab
                      ret_val, img0 = self.cap.retrieve()
                      if ret_val:
                          break

          # Validate BEFORE touching the frame, so a failed read is reported as a camera
          # error rather than as a failure inside cv2.flip on a missing frame.
          assert ret_val, 'Camera Error %s' % self.pipe
          if self.pipe == 0:
              img0 = cv2.flip(img0, 1)  # flip left-right

          # Print
          img_path = 'webcam.jpg'
          print('webcam %g: ' % self.count, end='')

          # Padded resize
          img = letterbox(img0, new_shape=self.img_size)[0]

          # Convert
          img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, channels first
          img = np.ascontiguousarray(img)

          return img_path, img, img0, None

      def __len__(self):
          return 0
  # cv2 video-reading calls used below:
  #   cap.grab()      fetch the next frame from the device/video; returns True on success
  #   cap.retrieve()  used after grab(); decodes the grabbed frame, also returns a success flag
  #   cap.read()      grab() + retrieve() combined: fetch and decode the next frame
  class LoadStreams:  # multiple IP or RTSP cameras
      """Iterator over the latest frames of one or more video streams (used by detect.py).

      One daemon thread per stream keeps ``self.imgs[i]`` updated. Each step returns
      ``(sources, img, img0, None)`` where ``img`` is the stacked, letterboxed,
      channel-first RGB batch and ``img0`` is the list of original frames.

      Args:
          sources: a stream address, a numeric camera index given as a string, or the path of
              a text file with one source per line (blank lines ignored).
          img_size: target size handed to letterbox().
      """

      def __init__(self, sources='streams.txt', img_size=640):
          self.mode = 'images'
          self.img_size = img_size

          # A file of sources is expanded into a list; anything else is one source.
          if os.path.isfile(sources):
              with open(sources, 'r') as f:
                  sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
          else:
              sources = [sources]

          n = len(sources)
          self.imgs = [None] * n  # latest frame of every stream
          self.sources = sources
          for i, s in enumerate(sources):
              # Start the thread to read frames from the video stream
              print('%g/%g: %s... ' % (i + 1, n, s), end='')
              # A purely numeric source is a local camera index. int() replaces the original
              # eval(): same result for digit strings, without executing the text.
              cap = cv2.VideoCapture(int(s) if s.isnumeric() else s)
              assert cap.isOpened(), 'Failed to open %s' % s
              w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
              h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
              fps = cap.get(cv2.CAP_PROP_FPS) % 100
              _, self.imgs[i] = cap.read()  # guarantee first frame
              # daemon=True: the reader thread ends when the main thread ends
              thread = Thread(target=self.update, args=([i, cap]), daemon=True)
              print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
              thread.start()
          print('')  # newline

          # check for common shapes
          # letterbox() defaults to auto=True here, so these are the rectangular inference shapes.
          s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0)  # inference shapes
          self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
          if not self.rect:
              print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

      def update(self, index, cap):
          """Read the next stream frames in a daemon thread, decoding every 4th grabbed frame."""
          n = 0
          while cap.isOpened():
              n += 1
              # _, self.imgs[index] = cap.read()
              cap.grab()
              if n == 4:  # read every 4th frame
                  ok, frame = cap.retrieve()
                  if ok:
                      # Keep the previous frame when decoding fails, so __next__ never
                      # hands a missing frame to letterbox().
                      self.imgs[index] = frame
                  n = 0
              time.sleep(0.01)  # wait time

      def __iter__(self):
          self.count = -1
          return self

      def __next__(self):
          self.count += 1
          img0 = self.imgs.copy()  # snapshot of the list (the frames themselves are not copied)
          if cv2.waitKey(1) == ord('q'):  # q to quit
              cv2.destroyAllWindows()
              raise StopIteration

          # Letterbox: resize + pad every frame
          img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]

          # Stack the frames into one batch array
          img = np.stack(img, 0)

          # Convert
          img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to batch x channels x H x W
          img = np.ascontiguousarray(img)

          return self.sources, img, img0, None

      def __len__(self):
          return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
  # Custom dataset: subclasses torch's Dataset and implements __len__ and __getitem__.
  class LoadImagesAndLabels(Dataset):  # for training/testing
      """Dataset of images and their label files for training/testing.

      Args:
          path: txt file of image paths, a folder of images, or a list of either.
          img_size: input image resolution.
          batch_size: batch size, used to assign a batch index to every image.
          augment: whether to apply data augmentation.
          hyp: hyper-parameter dict read by the augmentation code.
          rect: rectangular training (ignored when ``image_weights`` is set).
          image_weights: sample images through ``self.indices`` (set by train.py).
          cache_images: load every image into memory up front.
          single_cls: force every label's class column to 0.
          stride: model down-sampling stride.
          pad: padding added when computing rectangular batch shapes.
          rank: process rank; the progress bar is shown only for rank -1 or 0.

      ``__getitem__`` returns ``(image tensor, labels_out, image path, shapes)`` where
      ``labels_out`` has 6 columns; column 0 is filled with the in-batch image index by
      ``collate_fn``.
      """

      def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                   cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
          self.img_size = img_size
          self.augment = augment
          self.hyp = hyp
          self.image_weights = image_weights
          self.rect = False if image_weights else rect
          self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
          self.mosaic_border = [-img_size // 2, -img_size // 2]
          self.stride = stride

          def img2label_paths(img_paths):
              # Define label paths as a function of image paths
              sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
              # Swap only the real file suffix for '.txt'; a plain str.replace would also
              # rewrite the same text wherever else it occurs in the path.
              return [os.path.splitext(x.replace(sa, sb, 1))[0] + '.txt' for x in img_paths]

          try:
              f = []  # image files
              for p in path if isinstance(path, list) else [path]:
                  p = str(Path(p))  # os-agnostic
                  parent = str(Path(p).parent) + os.sep  # parent directory, with trailing separator
                  if os.path.isfile(p):  # file: a txt listing image paths
                      with open(p, 'r') as t:
                          t = t.read().splitlines()
                          f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                  elif os.path.isdir(p):  # folder: every "name.ext" entry directly inside it
                      f += glob.iglob(p + os.sep + '*.*')
                  else:
                      raise Exception('%s does not exist' % p)
              # Normalize separators and keep only supported image extensions.
              self.img_files = sorted(
                  [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
              assert len(self.img_files) > 0, 'No images found'
          except Exception as e:
              raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

          # Check cache
          self.label_files = img2label_paths(self.img_files)  # labels
          cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
          if os.path.isfile(cache_path):
              cache = torch.load(cache_path)  # load
              if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
                  cache = self.cache_labels(cache_path)  # re-cache
          else:
              cache = self.cache_labels(cache_path)  # cache

          # Read cache
          cache.pop('hash')  # remove hash
          labels, shapes = zip(*cache.values())
          self.labels = list(labels)
          self.shapes = np.array(shapes, dtype=np.float64)
          self.img_files = list(cache.keys())  # update
          self.label_files = img2label_paths(cache.keys())  # update

          n = len(shapes)  # number of images
          # Builtin int instead of the np.int alias, which NumPy has removed.
          bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
          nb = bi[-1] + 1  # number of batches
          self.batch = bi  # batch index of image
          self.n = n

          # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
          if self.rect:
              # Sort by aspect ratio
              s = self.shapes  # wh
              ar = s[:, 1] / s[:, 0]  # aspect ratio
              irect = ar.argsort()  # indices that sort ar ascending
              self.img_files = [self.img_files[i] for i in irect]
              self.label_files = [self.label_files[i] for i in irect]
              self.labels = [self.labels[i] for i in irect]
              self.shapes = s[irect]  # wh
              ar = ar[irect]

              # Set training image shapes
              shapes = [[1, 1]] * nb
              for i in range(nb):
                  ari = ar[bi == i]
                  mini, maxi = ari.min(), ari.max()
                  if maxi < 1:  # every ratio in the batch is below 1
                      shapes[i] = [maxi, 1]
                  elif mini > 1:  # every ratio in the batch is above 1
                      shapes[i] = [1, 1 / mini]

              self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

          # Check labels
          create_datasubset, extract_bounding_boxes = False, False
          nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
          pbar = enumerate(self.label_files)
          if rank in [-1, 0]:
              pbar = tqdm(pbar)
          for i, file in pbar:
              l = self.labels[i]  # label
              if l is not None and l.shape[0]:
                  assert l.shape[1] == 5, '> 5 label columns: %s' % file
                  assert (l >= 0).all(), 'negative labels: %s' % file
                  assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                  if np.unique(l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                      nd += 1  # print('WARNING: duplicate rows in %s' % self.label_files[i])  # duplicate rows
                  if single_cls:
                      l[:, 0] = 0  # force dataset into single-class mode
                  self.labels[i] = l
                  nf += 1  # file found

                  # Create subdataset (a smaller dataset)
                  if create_datasubset and ns < 1E4:
                      if ns == 0:
                          create_folder(path='./datasubset')
                          os.makedirs('./datasubset/images')
                      exclude_classes = 43
                      if exclude_classes not in l[:, 0]:
                          ns += 1
                          # shutil.copy(src=self.img_files[i], dst='./datasubset/images/')  # copy image
                          with open('./datasubset/images.txt', 'a') as f:
                              f.write(self.img_files[i] + '\n')

                  # Extract object detection boxes for a second stage classifier
                  if extract_bounding_boxes:
                      p = Path(self.img_files[i])
                      img = cv2.imread(str(p))
                      h, w = img.shape[:2]
                      for j, x in enumerate(l):
                          f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
                          if not os.path.exists(Path(f).parent):
                              os.makedirs(Path(f).parent)  # make new output folder

                          b = x[1:] * [w, h, w, h]  # box
                          b[2:] = b[2:].max()  # rectangle to square
                          b[2:] = b[2:] * 1.3 + 30  # pad
                          b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                          b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                          b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                          assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
              else:
                  ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                  # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

              if rank in [-1, 0]:
                  pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
                      cache_path, nf, nm, ne, nd, n)
          if nf == 0:  # No labels found
              s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
              print(s)
              assert not augment, '%s. Can not train without labels.' % s

          # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
          self.imgs = [None] * n
          if cache_images:
              gb = 0  # Gigabytes of cached images
              pbar = tqdm(range(len(self.img_files)), desc='Caching images')
              self.img_hw0, self.img_hw = [None] * n, [None] * n
              for i in pbar:  # max 10k images
                  self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i)  # img, hw_original, hw_resized
                  gb += self.imgs[i].nbytes
                  pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

      def cache_labels(self, path='labels.cache'):
          """Cache dataset labels, check images and read shapes.

          Builds a dict mapping each image path to ``[labels, shape]``, adds a 'hash'
          entry, saves the dict to `path` with torch.save and returns it. Images or
          labels that raise while being read are reported and left out.
          """
          cache = {}  # image path -> [labels, shape]
          pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
          for (img, label) in pbar:
              try:
                  l = []
                  im = Image.open(img)
                  im.verify()  # PIL verify
                  shape = exif_size(im)  # image size
                  assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
                  if os.path.isfile(label):
                      with open(label, 'r') as f:
                          l = np.array([row.split() for row in f.read().splitlines()], dtype=np.float32)  # labels
                  if len(l) == 0:
                      l = np.zeros((0, 5), dtype=np.float32)
                  cache[img] = [l, shape]
              except Exception as e:
                  print('WARNING: Ignoring corrupted image and/or label %s: %s' % (img, e))

          cache['hash'] = get_hash(self.label_files + self.img_files)
          torch.save(cache, path)  # save for next time
          return cache

      def __len__(self):
          return len(self.img_files)

      # def __iter__(self):
      #     self.count = -1
      #     print('ran dataset iter')
      #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
      #     return self

      def __getitem__(self, index):
          # self.indices is set in train.py and must be used together with that code.
          # image_weights are per-image sampling weights derived from the per-class label counts;
          # when enabled, the requested index is remapped through self.indices.
          if self.image_weights:
              index = self.indices[index]

          hyp = self.hyp
          mosaic = self.mosaic and random.random() < hyp['mosaic']  # mosaic applied with probability hyp['mosaic']
          if mosaic:
              # Load mosaic
              img, labels = load_mosaic(self, index)
              shapes = None

              # MixUp https://arxiv.org/pdf/1710.09412.pdf
              if random.random() < hyp['mixup']:
                  img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
                  r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                  img = (img * r + img2 * (1 - r)).astype(np.uint8)
                  labels = np.concatenate((labels, labels2), 0)

          else:
              # Load image (already resized so that its longer side equals img_size)
              img, (h0, w0), (h, w) = load_image(self, index)

              # Letterbox
              shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
              img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
              shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

              # Load labels
              labels = []
              x = self.labels[index]
              if x.size > 0:
                  # Normalized xywh to pixel xyxy format, shifted by the letterbox padding
                  labels = x.copy()
                  labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0]  # pad width
                  labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1]  # pad height
                  labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
                  labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]

          if self.augment:
              # Augment imagespace (load_mosaic already applies random_perspective itself)
              if not mosaic:
                  img, labels = random_perspective(img, labels,
                                                   degrees=hyp['degrees'],
                                                   translate=hyp['translate'],
                                                   scale=hyp['scale'],
                                                   shear=hyp['shear'],
                                                   perspective=hyp['perspective'])

              # Augment colorspace: random hue / saturation / value change, in place
              augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

              # Apply cutouts
              # if random.random() < 0.9:
              #     labels = cutout(img, labels)

          nL = len(labels)  # number of labels
          if nL:
              labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
              labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
              labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

          if self.augment:
              # flip up-down
              if random.random() < hyp['flipud']:
                  img = np.flipud(img)
                  if nL:
                      labels[:, 2] = 1 - labels[:, 2]

              # flip left-right
              if random.random() < hyp['fliplr']:
                  img = np.fliplr(img)
                  if nL:
                      labels[:, 1] = 1 - labels[:, 1]

          # Column 0 is left at zero here; collate_fn writes the in-batch image index into it.
          labels_out = torch.zeros((nL, 6))
          if nL:
              labels_out[:, 1:] = torch.from_numpy(labels)

          # Convert: reverse the channel axis (BGR to RGB) and move channels first (HWC to CHW)
          img = img[:, :, ::-1].transpose(2, 0, 1)
          img = np.ascontiguousarray(img)

          return torch.from_numpy(img), labels_out, self.img_files[index], shapes

      # The DataLoader packs a batch through collate_fn; overriding it here records
      # which labels in the concatenated batch belong to which image.
      @staticmethod
      def collate_fn(batch):
          """Stack the images and concatenate the labels, tagging each label row with its image index."""
          img, label, path, shapes = zip(*batch)  # transposed
          for i, l in enumerate(label):
              l[:, 0] = i  # add target image index for build_targets()
          return torch.stack(img, 0), torch.cat(label, 0), path, shapes
  # Ancillary functions --------------------------------------------------------------------------------------------------
  def load_image(self, index):
      """Load one image from the dataset; return (img, original hw, resized hw).

      `self` is the dataset object. A cached image (``self.imgs[index]`` not None) is
      returned together with its stored sizes. Otherwise the file is read with OpenCV
      and scaled so that its longer side equals ``self.img_size``.
      """
      if self.imgs[index] is not None:  # cached
          return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized

      # not cached: read from disk
      path = self.img_files[index]
      img = cv2.imread(path)  # BGR
      assert img is not None, 'Image Not Found ' + path
      h0, w0 = img.shape[:2]  # orig hw
      r = self.img_size / max(h0, w0)  # resize image to img_size
      if r != 1:  # always resize down, only resize up if training with augmentation
          # Interpolation depends on the ratio: INTER_AREA when shrinking without augmentation,
          # INTER_LINEAR otherwise.
          if r < 1 and not self.augment:
              interp = cv2.INTER_AREA
          else:
              interp = cv2.INTER_LINEAR
          img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
      return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
  def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
      """Randomly change hue, saturation and value of a BGR image, in place.

      Three factors are drawn uniformly from [-1, 1], scaled by the three gains and
      offset by 1; each HSV channel is then remapped through a 256-entry lookup table.
      The result is written back into `img`; nothing is returned.
      """
      gains = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
      hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
      hue, sat, val = cv2.split(hsv)
      dtype = img.dtype  # uint8

      # One lookup table per channel: hue wraps modulo 180, saturation and value are clipped to 0..255.
      levels = np.arange(0, 256, dtype=np.int16)
      lut_hue = ((levels * gains[0]) % 180).astype(dtype)
      lut_sat = np.clip(levels * gains[1], 0, 255).astype(dtype)
      lut_val = np.clip(levels * gains[2], 0, 255).astype(dtype)

      # Remap the channels, merge them again and convert back to BGR into the caller's array.
      channels = (cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))
      img_hsv = cv2.merge(channels).astype(dtype)
      cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

      # Histogram equalization
      # if random.random() < 0.2:
      #     for i in range(3):
      #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])
  def load_mosaic(self, index):
      """Load 4 images into one mosaic and return (img4, labels4).

      `self` is the dataset object. The image at `index` plus three randomly chosen
      images are pasted around a random centre on a (2*img_size, 2*img_size) canvas
      filled with 114. The labels are converted from normalized xywh to pixel xyxy on
      that canvas and clipped to it; image and labels are then passed through
      random_perspective() with ``border=self.mosaic_border``.
      """
      labels4 = []
      s = self.img_size
      # Random mosaic centre
      yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
      indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
      for i, index in enumerate(indices):
          # Load image
          img, _, (h, w) = load_image(self, index)

          # place img in img4
          # "a" coordinates are on the large canvas, "b" coordinates on the small source image.
          if i == 0:  # top left
              img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
              x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
              x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
          elif i == 1:  # top right
              x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
              x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
          elif i == 2:  # bottom left
              x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
              # The right edge of the source crop is the image's own right edge. The previous
              # max(xc, w) could exceed w and only gave the same pixels because slicing clips.
              x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
          elif i == 3:  # bottom right
              x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
              x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

          # Paste the crop of the small image onto the canvas
          img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
          # Offset from small-image coordinates to canvas coordinates, used to move the boxes
          padw = x1a - x1b
          padh = y1a - y1b

          # Labels
          x = self.labels[index]
          labels = x.copy()
          if x.size > 0:  # Normalized xywh to pixel xyxy format
              labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
              labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
              labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
              labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
          labels4.append(labels)

      # Concat/clip labels
      if len(labels4):
          labels4 = np.concatenate(labels4, 0)
          np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_perspective
          # img4, labels4 = replicate(img4, labels4)  # replicate

      # Augment: the canvas is [2*img_size, 2*img_size]; random_perspective is given the mosaic border to remove
      img4, labels4 = random_perspective(img4, labels4,
                                         degrees=self.hyp['degrees'],
                                         translate=self.hyp['translate'],
                                         scale=self.hyp['scale'],
                                         shear=self.hyp['shear'],
                                         perspective=self.hyp['perspective'],
                                         border=self.mosaic_border)  # border to remove

      return img4, labels4
  def replicate(img, labels):
      """Copy the smaller half of the labelled boxes to random positions in the image.

      Args:
          img: image array, modified in place.
          labels: array with rows (class, x1, y1, x2, y2) in pixel coordinates.

      Returns:
          ``(img, labels)`` where `labels` has one extra row per pasted copy.
      """
      height, width = img.shape[:2]
      boxes = labels[:, 1:].astype(int)
      sides = ((boxes[:, 2] - boxes[:, 0]) + (boxes[:, 3] - boxes[:, 1])) / 2  # side length (pixels)
      n_copy = round(sides.size * 0.5)
      for i in sides.argsort()[:n_copy]:  # smallest indices
          src_x1, src_y1, src_x2, src_y2 = boxes[i]
          bh, bw = src_y2 - src_y1, src_x2 - src_x1
          # Random top-left corner such that the copy stays inside the image (y drawn first, then x)
          top = int(random.uniform(0, height - bh))
          left = int(random.uniform(0, width - bw))
          dst_x1, dst_y1, dst_x2, dst_y2 = left, top, left + bw, top + bh
          img[dst_y1:dst_y2, dst_x1:dst_x2] = img[src_y1:src_y2, src_x1:src_x2]  # img4[ymin:ymax, xmin:xmax]
          new_row = [[labels[i, 0], dst_x1, dst_y1, dst_x2, dst_y2]]
          labels = np.append(labels, new_row, axis=0)

      return img, labels
  708. # 图像缩放: 保持图片的宽高比例,剩下的部分采用灰色填充。
  709. def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
  710. # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
  711. shape = img.shape[:2] # current shape [height, width]
  712. if isinstance(new_shape, int):
  713. new_shape = (new_shape, new_shape)
  714. # Scale ratio (new / old) # 计算缩放因子
  715. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  716. """
  717. 缩放(resize)到输入大小img_size的时候,如果没有设置上采样的话,则只进行下采样
  718. 因为上采样图片会让图片模糊,对训练不友好影响性能。
  719. """
  720. if not scaleup: # only scale down, do not scale up (for better test mAP)
  721. r = min(r, 1.0)
  722. # Compute padding
  723. ratio = r, r # width, height ratios
  724. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  725. dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  726. if auto: # minimum rectangle # 获取最小的矩形填充
  727. dw, dh = np.mod(dw, 32), np.mod(dh, 32) # wh padding
  728. # 如果scaleFill=True,则不进行填充,直接resize成img_size, 任由图片进行拉伸和压缩
  729. elif scaleFill: # stretch
  730. dw, dh = 0.0, 0.0
  731. new_unpad = (new_shape[1], new_shape[0])
  732. ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
  733. # 计算上下左右填充大小
  734. dw /= 2 # divide padding into 2 sides
  735. dh /= 2
  736. if shape[::-1] != new_unpad: # resize
  737. img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
  738. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  739. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  740. # 进行填充
  741. img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
  742. return img, ratio, (dw, dh)
  743. # 随机透视变换
  744. # 计算方法为坐标向量和变换矩阵的乘积
def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)):
    """Apply one random perspective/affine warp to an image and its box labels.

    The transform is the matrix product T @ S @ R @ P @ C (centre, perspective,
    rotation+scale, shear, translation), applied to the image with OpenCV and
    to the four corners of every box.

    Args:
        img: image array of shape (h, w, c).
        targets: rows of [cls, x1, y1, x2, y2] in pixels; may be empty.
        degrees: rotation angle is drawn uniformly from [-degrees, degrees].
        translate: translation is drawn uniformly from
            [0.5 - translate, 0.5 + translate] times the output width/height.
        scale: scale factor is drawn uniformly from [1 - scale, 1 + scale].
        shear: x and y shear angles (degrees) are each drawn from [-shear, shear].
        perspective: the two perspective terms are each drawn from
            [-perspective, perspective]; a truthy value selects warpPerspective.
        border: (border[0], border[1]) is added twice to the input height and
            width to obtain the output size. NOTE(review): the mosaic caller
            passes this as a "border to remove", so it is presumably negative
            there - confirm against the caller.

    Returns:
        (img, targets): the warped image and the boxes that survive
        `box_candidates`, with their coordinates replaced by the warped,
        clipped axis-aligned boxes.
    """
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    # Output canvas size
    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center: move the image centre to the origin
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined transform matrix; @ is matrix multiplication
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            # Perspective warp uses the full 3x3 matrix
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            # Affine warp uses only the top two rows of M
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points: four corners per box in homogeneous coordinates
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale by the homogeneous coordinate
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # create new boxes: axis-aligned bounds of the four warped corners
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # clip boxes to the output canvas
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates: the pre-warp boxes are multiplied by the scale factor s
        # before being compared with the warped boxes
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets
  821. def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1): # box1(4,n), box2(4,n)
  822. # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
  823. w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
  824. w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
  825. ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16)) # aspect ratio
  826. return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr) # candidates
  827. # cutout数据增强
  828. def cutout(image, labels):
  829. # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
  830. h, w = image.shape[:2]
  831. def bbox_ioa(box1, box2):
  832. # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
  833. box2 = box2.transpose()
  834. # Get the coordinates of bounding boxes
  835. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  836. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  837. # Intersection area
  838. inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
  839. (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
  840. # box2 area
  841. box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
  842. # Intersection over box2 area
  843. return inter_area / box2_area
  844. # create random masks
  845. scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
  846. for s in scales:
  847. mask_h = random.randint(1, int(h * s))
  848. mask_w = random.randint(1, int(w * s))
  849. # box
  850. xmin = max(0, random.randint(0, w) - mask_w // 2)
  851. ymin = max(0, random.randint(0, h) - mask_h // 2)
  852. xmax = min(w, xmin + mask_w)
  853. ymax = min(h, ymin + mask_h)
  854. # apply random color mask
  855. image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
  856. # return unobscured labels
  857. if len(labels) and s > 0.03:
  858. box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
  859. ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area
  860. labels = labels[ioa < 0.60] # remove >60% obscured labels
  861. return labels
  862. def reduce_img_size(path='path/images', img_size=1024): # from utils.datasets import *; reduce_img_size()
  863. # creates a new ./images_reduced folder with reduced size images of maximum size img_size
  864. path_new = path + '_reduced' # reduced images path
  865. create_folder(path_new)
  866. for f in tqdm(glob.glob('%s/*.*' % path)):
  867. try:
  868. img = cv2.imread(f)
  869. h, w = img.shape[:2]
  870. r = img_size / max(h, w) # size ratio
  871. if r < 1.0:
  872. img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_AREA) # _LINEAR fastest
  873. fnew = f.replace(path, path_new) # .replace(Path(f).suffix, '.jpg')
  874. cv2.imwrite(fnew, img)
  875. except:
  876. print('WARNING: image failure %s' % f)
  877. def recursive_dataset2bmp(dataset='path/dataset_bmp'): # from utils.datasets import *; recursive_dataset2bmp()
  878. # Converts dataset to bmp (for faster training)
  879. formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
  880. for a, b, files in os.walk(dataset):
  881. for file in tqdm(files, desc=a):
  882. p = a + '/' + file
  883. s = Path(file).suffix
  884. if s == '.txt': # replace text
  885. with open(p, 'r') as f:
  886. lines = f.read()
  887. for f in formats:
  888. lines = lines.replace(f, '.bmp')
  889. with open(p, 'w') as f:
  890. f.write(lines)
  891. elif s in formats: # replace image
  892. cv2.imwrite(p.replace(s, '.bmp'), cv2.imread(p))
  893. if s != '.bmp':
  894. os.system("rm '%s'" % p)
  895. def imagelist2folder(path='path/images.txt'): # from utils.datasets import *; imagelist2folder()
  896. # Copies all the images in a text file (list of images) into a folder
  897. create_folder(path[:-4])
  898. with open(path, 'r') as f:
  899. for line in f.read().splitlines():
  900. os.system('cp "%s" %s' % (line, path[:-4]))
  901. print(line)
  902. def create_folder(path='./new'):
  903. # Create folder
  904. if os.path.exists(path):
  905. shutil.rmtree(path) # delete output folder
  906. os.makedirs(path) # make new output folder

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/929517
推荐阅读
相关标签
  

闽ICP备14008679号