2021SC@SDUSC
This post analyzes the plot.py file. As its name suggests, it mainly contains visualization code rather than the core algorithm code.
from copy import copy
from pathlib import Path

import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
from PIL import Image, ImageDraw, ImageFont

from utils.general import user_config_dir, is_ascii, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness
copy: used for copying objects. The module provides two main functions, copy.copy and copy.deepcopy, for shallow and deep copies respectively (see the short sketch after this list).
Path, cv2, math, numpy and pandas were already introduced in the post on general.py.
matplotlib: the best-known Python plotting library; it offers a MATLAB-like command API and is the main external library used in this file.
seaborn: a Python visualization library based on matplotlib, providing a higher-level API on top of it.
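To make the shallow vs. deep copy distinction concrete, here is a tiny sketch with toy data (not from plot.py):

```python
import copy

a = [[1, 2], [3, 4]]
shallow = copy.copy(a)       # new outer list, but the inner lists are shared
deep = copy.deepcopy(a)      # inner lists are duplicated as well
a[0][0] = 99
print(shallow[0][0], deep[0][0])  # 99 1
```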
class Colors:
    # Ultralytics color palette https://ultralytics.com/
    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
               '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb('#' + c) for c in hex]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))

hex: the colors in hexadecimal format
palette: the colors converted to RGB format
n: the number of colors in the palette
The hex2rgb function converts a color given in hexadecimal notation into RGB format.
The __call__ function returns the color at index i; when i exceeds n, the index wraps around to i modulo n (see the usage sketch below).
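A minimal usage sketch of the class above (in the YOLOv5 repo a module-level instance, colors = Colors(), is created and reused elsewhere in the file, e.g. by plot_images):

```python
colors = Colors()
print(colors.n)             # 20 colors in the palette
print(colors(3))            # (255, 178, 29): RGB of palette entry 3 ('FFB21D')
print(colors(23))           # same color, because 23 % 20 == 3
print(colors(3, bgr=True))  # (29, 178, 255): reordered for OpenCV drawing
```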
def check_font(font='Arial.ttf', size=10):
    # Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
    font = Path(font)
    font = font if font.exists() else (CONFIG_DIR / font.name)
    try:
        return ImageFont.truetype(str(font) if font.exists() else font.name, size)
    except Exception as e:  # download if missing
        url = "https://ultralytics.com/assets/" + font.name
        print(f'Downloading {url} to {font}...')
        torch.hub.download_url_to_file(url, str(font))
        return ImageFont.truetype(str(font), size)
font: the font to check
This function checks whether the corresponding font file exists; if it does not, it downloads the font from the web to the expected path. CONFIG_DIR is the user configuration directory obtained from user_config_dir() at module level (that line is not shown in the snippet).
PIL's ImageFont module defines a class of the same name, ImageFont. Instances of this class store bitmap fonts and are used by the text() method of the ImageDraw class. It is not covered further here; if interested, see the ImageFont module documentation — Pillow (PIL Fork) 8.4.0.
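For reference, this is roughly how the TrueType font object returned by check_font is consumed by ImageDraw; a standalone sketch that assumes Arial.ttf is already available locally (check_font would download it otherwise):

```python
from PIL import Image, ImageDraw, ImageFont

font = ImageFont.truetype('Arial.ttf', 12)          # what check_font returns on success
im = Image.new('RGB', (200, 50), (255, 255, 255))   # blank white canvas
ImageDraw.Draw(im).text((5, 5), 'hello', fill=(0, 0, 0), font=font)
```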
class Annotator:
    check_font()  # download TTF if necessary

    # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
    def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=True):
        assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
        self.pil = pil
        if self.pil:  # use PIL
            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            self.font = check_font(font, size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
            self.fh = self.font.getsize('a')[1] - 3  # font height
        else:  # use cv2
            self.im = im
        self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width

    def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
        # Add one xyxy box to image with label
        if self.pil or not is_ascii(label):
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                w, h = self.font.getsize(label)  # text width
                self.draw.rectangle([box[0], box[1] - self.fh, box[0] + w + 1, box[1] + 1], fill=color)
                # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
                self.draw.text((box[0], box[1] - h), label, fill=txt_color, font=self.font)
        else:  # cv2
            c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(self.im, c1, c2, color, thickness=self.lw, lineType=cv2.LINE_AA)
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]
                c2 = c1[0] + w, c1[1] - h - 3
                cv2.rectangle(self.im, c1, c2, color, -1, cv2.LINE_AA)  # filled
                cv2.putText(self.im, label, (c1[0], c1[1] - 2), 0, self.lw / 3, txt_color, thickness=tf,
                            lineType=cv2.LINE_AA)

    def rectangle(self, xy, fill=None, outline=None, width=1):
        # Add rectangle to image (PIL-only)
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255)):
        # Add text to image (PIL-only)
        w, h = self.font.getsize(text)  # text width, height
        self.draw.text((xy[0], xy[1] - h + 1), text, fill=txt_color, font=self.font)

    def result(self):
        # Return annotated image as array
        return np.asarray(self.im)

__init__ method:
im: the image
line_width: line width
font_size: font size
font: font name
pil: whether to use Pillow
If Pillow is used, the image is converted to a PIL Image; fh is the font height.
ImageDraw provides simple 2D drawing on Image objects; it can be used to create new images and to annotate or retouch existing ones. For details see the ImageDraw module documentation — Pillow (PIL Fork) 8.4.0.
lw is the line width.
box_label method: adds one box in xyxy format to the image, together with its label.
box: the box in xyxy format
label: the label
Whether PIL or OpenCV is used, a box in xyxy format (the coordinates of its top-left and bottom-right corners) is drawn on the image and its label is written next to it.
rectangle method:
draws a rectangle on the image.
text method:
adds a box's label text to the image.
result method:
returns the final image as a numpy array.
This class implements drawing predicted boxes on an image and labelling them.
The resulting image (figure omitted here) shows each predicted box annotated with its predicted class and confidence.
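A minimal usage sketch of the class, assuming it is importable from the YOLOv5 repo (where the file is utils/plots.py and colors is the module-level Colors() instance); the image, box and label below are made up:

```python
import numpy as np
from utils.plots import Annotator, colors

im = np.ascontiguousarray(np.full((640, 640, 3), 114, dtype=np.uint8))  # gray canvas
annotator = Annotator(im, line_width=2, pil=False)                      # use the cv2 branch
annotator.box_label([100, 100, 300, 250], label='person 0.87', color=colors(0))
result = annotator.result()   # numpy array, ready for cv2.imwrite / plt.imshow
print(result.shape)           # (640, 640, 3)
```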
def hist2d(x, y, n=100):
    # 2d histogram used in labels.png and evolve.png
    xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
    hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
    xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
    yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
    return np.log(hist[xidx, yidx])
Based on the 2D histogram of x and y, this returns a value per point that is used as its plotting color: points falling into bins with many samples get brighter colors, points in sparse bins get darker ones.
x and y are both numpy arrays.
np.linspace(start, stop, num, endpoint, retstep, dtype) returns evenly spaced numbers over the specified interval, i.e. num evenly distributed samples in [start, stop].
np.clip(a, a_min, a_max, out=None) limits a to the range [a_min, a_max]: values larger than a_max become a_max, values smaller than a_min become a_min, and everything else is returned unchanged.
np.histogram2d computes the 2D histogram of two 1D sample arrays.
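A tiny sketch of the two numpy building blocks, histogram2d and digitize, with made-up data:

```python
import numpy as np

x = np.array([0.1, 0.2, 0.9, 0.95])
y = np.array([0.1, 0.15, 0.8, 0.85])
hist, xedges, yedges = np.histogram2d(x, y, bins=2)  # 2x2 grid of counts
print(hist)                                          # [[2. 0.] [0. 2.]]
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
print(xidx)                                          # [0 0 1 1]: the x-bin of every point
```

hist2d then looks up hist[xidx, yidx] so that every point receives the (log) count of the bin it falls into.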
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    from scipy.signal import butter, filtfilt

    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    def butter_lowpass(cutoff, fs, order):
        nyq = 0.5 * fs
        normal_cutoff = cutoff / nyq
        return butter(order, normal_cutoff, btype='low', analog=False)

    b, a = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data)  # forward-backward filter
data: the raw data
cutoff: the cutoff frequency; components above it are filtered out
fs: the sampling frequency
This function implements a low-pass filter, i.e. it keeps the low-frequency part of the signal and discards the high-frequency part; "low-pass" means low frequencies can pass through while high frequencies cannot.
butter designs the filter and filtfilt applies it (forwards and backwards, so the output has no phase shift).
For details see the official documentation: scipy.signal.butter — SciPy v1.7.1 Manual.
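A self-contained sketch of the same forward-backward filtering on a made-up noisy signal (the parameter values are just the function's defaults):

```python
import numpy as np
from scipy.signal import butter, filtfilt

fs, cutoff = 50000, 1500                      # sampling frequency and cutoff frequency
t = np.linspace(0, 1, fs, endpoint=False)
data = np.sin(2 * np.pi * 5 * t) + 0.3 * np.random.randn(fs)  # slow sine + noise

b, a = butter(5, cutoff / (0.5 * fs), btype='low', analog=False)  # design the filter
smooth = filtfilt(b, a, data)                 # zero-phase filtering: no time shift
print(data.shape, smooth.shape)               # (50000,) (50000,)
```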
def output_to_target(output):
    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
    targets = []
    for i, o in enumerate(output):
        for *box, conf, cls in o.cpu().numpy():
            targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
    return np.array(targets)
output: the model's output
This function converts the model output into the desired format, [batch_id, class_id, x, y, w, h, conf].
Each element of output has the format [boxes, conf, cls], i.e. the predicted boxes, their confidences and their classes.
The target format is [batch_id, class_id, x, y, w, h, conf] * M, where M is the number of predicted boxes in the whole batch.
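To make the format change concrete, here is a standalone sketch with two made-up images of one detection each (the xyxy-to-xywh conversion is inlined instead of importing xyxy2xywh from the repo):

```python
import numpy as np
import torch

output = [                                           # one tensor per image: [x1, y1, x2, y2, conf, cls]
    torch.tensor([[10., 20., 50., 80., 0.9, 0.]]),
    torch.tensor([[30., 30., 70., 90., 0.7, 2.]]),
]

def xyxy2xywh_np(box):                               # same idea as utils.general.xyxy2xywh
    x1, y1, x2, y2 = box
    return [(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1]

targets = []
for i, o in enumerate(output):
    for *box, conf, cls in o.numpy():
        targets.append([i, cls, *xyxy2xywh_np(box), conf])

print(np.array(targets))
# rows are [batch_id, class_id, x, y, w, h, conf]:
# [[ 0.   0.  30.  50.  40.  60.   0.9]
#  [ 1.   2.  50.  60.  40.  60.   0.7]]
```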
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
images: a batch of images
targets: the batch's labels
paths: the file names of the batch
fname: the file path where the resulting mosaic image is saved
names: the class names
max_size: the maximum size of the visualized mosaic image
max_subplots: at most 16 images of the batch are visualized
    # Plot image grid with labels
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()
    if np.max(images[0]) <= 1:
        images *= 255.0  # de-normalise (optional)
Convert images and targets from tensors to numpy arrays.
If the image values are in 0-1, multiply them by 255 to convert to 0-255.
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
bs, h and w are the batch size and the height and width of the images; ns is the number of rows/columns of the (square) subplot grid.
    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, im in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        im = im.transpose(1, 2, 0)
        mosaic[y:y + h, x:x + w, :] = im
mosaic is the initialized big image.
Iterating over images, x and y are the pixel coordinates of each image's block inside the mosaic.
This piece of code copies each image into its block of the mosaic (a small worked example of the block arithmetic follows).
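A small worked example of the block-origin arithmetic with made-up sizes: for bs = 4 images of 2x2 pixels, ns = 2 and the four blocks fill the mosaic column by column.

```python
w = h = 2   # toy image size
ns = 2      # 2x2 grid of subplots
for i in range(4):
    x, y = int(w * (i // ns)), int(h * (i % ns))  # top-left corner of block i
    print(i, (x, y))
# 0 (0, 0)
# 1 (0, 2)
# 2 (2, 0)
# 3 (2, 2)
```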
    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
The mosaic is resized if necessary; h and w are the new height and width, and scale is the shrinking factor (the resize only happens when scale < 1).
    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs)
    for i in range(i + 1):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(targets) > 0:
            ti = targets[targets[:, 0] == i]  # image targets
            boxes = xywh2xyxy(ti[:, 2:6]).T
            classes = ti[:, 1].astype('int')
            labels = ti.shape[1] == 6  # labels if no conf column
            conf = None if labels else ti[:, 6]  # check for confidence presence (label vs pred)

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coords need scale if image scales
                    boxes *= scale
                boxes[[0, 2]] += x
                boxes[[1, 3]] += y
                for j, box in enumerate(boxes.T.tolist()):
                    cls = classes[j]
                    color = colors(cls)
                    cls = names[cls] if names else cls
                    if labels or conf[j] > 0.25:  # 0.25 conf thresh
                        label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
                        annotator.box_label(box, label, color=color)
    annotator.im.save(fname)  # save

Next the images are annotated. fs is the font size and annotator is an instance of the Annotator class defined above.
x and y are the top-left corner of each block; the annotator draws a rectangle as the block border, and if paths is not empty it also writes the file name of the image.
ti holds the targets of the current image; boxes, classes, labels and conf are the boxes, the classes, a flag for whether ground-truth labels are being visualized, and the confidences. labels is True when ti.shape[1] == 6, i.e. there is no confidence column, which means ground-truth labels rather than predictions are being visualized.
If the boxes are normalized they are scaled up to the image size; otherwise, if the mosaic was shrunk, they are multiplied by scale.
Then the top-left corner of the block is added to the boxes: their original coordinates are relative to the top-left corner of their own grid cell, and adding the block origin converts them to coordinates in the whole mosaic.
Finally the boxes are drawn on the subplots; cls and color are the class and its color. When drawing predictions, a box is only drawn if conf > 0.25; this threshold removes low-confidence duplicate predictions.
At the end, the mosaic is saved to the given path.
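Finally, a hedged end-to-end sketch of calling plot_images with made-up data; it assumes the YOLOv5 repo is on the path (the file is named utils/plots.py there). With 6-column targets the function treats them as ground-truth labels in normalized xywh format:

```python
import numpy as np
import torch
from utils.plots import plot_images

images = torch.rand(2, 3, 320, 320)              # 2 images, values in 0-1 (de-normalised inside)
targets = np.array([[0, 0, 0.5, 0.5, 0.2, 0.3],  # [batch_id, class_id, x, y, w, h]
                    [1, 1, 0.4, 0.6, 0.1, 0.2]])
plot_images(images, targets, paths=None, fname='batch0.jpg', names=['person', 'car'])
```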
The most important part of this post is the code that draws boxes on images and labels them with their classes. Some functions have not been covered yet; they will be introduced in the next post.