This anchor clustering uses k-means with 1-IOU(bboxes, anchors) as the distance metric. The main difference from the original blogger's version is in how the files are read: the year-based file reading has been removed, which makes it easier to use with your own dataset, and everything is merged into a single file that can be run directly (a small example of the path slicing it relies on follows the directory tree below). My own folder layout is as follows.

Directory structure:
data
├── Annotations
├── images
├── ImageSets
├── labels
└── train.txt
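To make the path handling concrete, here is a tiny sketch of what the line[12:-5] slice in the script below does, assuming each line of train.txt has exactly the form data/images/xxxx.jpg with a trailing newline (the file stem 000123 is made up for illustration):

# hypothetical train.txt line; "000123" is a made-up file stem
line = "data/images/000123.jpg\n"
stem = line[12:-5]   # drops the 12-char prefix "data/images/" and the 5-char suffix ".jpg\n"
print(stem)                                  # -> 000123
print("data/Annotations/" + stem + ".xml")   # -> data/Annotations/000123.xml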
import os
import glob
import random

import numpy as np
from tqdm import tqdm
# install with: pip install lxml -i https://pypi.tuna.tsinghua.edu.cn/simple
from lxml import etree
from scipy.cluster.vq import kmeans


class VOCDataSet(object):
    # collect the xml annotation files that correspond to the training images
    def __init__(self, voc_root, txt_name: str = "train.txt"):
        # directory holding the label files: data/Annotations/
        self.annotations_root = os.path.join(voc_root, "Annotations")

        # data/train.txt lists the paths of the training images
        txt_path = os.path.join(voc_root, txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        # map each training image path to its annotation file
        with open(txt_path) as read:
            # each line is data/images/xxxx.jpg; the slice line[12:-5] extracts "xxxx",
            # which is then joined into data/Annotations/xxxx.xml
            self.xml_list = [os.path.join(self.annotations_root, line[12:-5] + ".xml")
                             for line in read.readlines() if len(line.strip()) > 0]

        # check file
        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
        for xml_path in self.xml_list:
            assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)

    def __len__(self):
        return len(self.xml_list)

    def parse_xml_to_dict(self, xml):
        """
        Parse an xml tree into a dict, following TensorFlow's recursive_parse_xml_to_dict.
        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree
        Returns:
            Python dictionary holding XML contents.
        """
        if len(xml) == 0:  # reached a leaf node, return its tag and text directly
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # recurse into child tags
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # there may be several objects, so collect them in a list
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

    def get_info(self):
        im_wh_list = []
        boxes_wh_list = []
        for xml_path in tqdm(self.xml_list, desc="read data info."):
            # read xml
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            # xml = etree.fromstring(xml_str.encode('utf-8'))  # use this if an encoding-declaration ValueError occurs (see below)
            data = self.parse_xml_to_dict(xml)["annotation"]
            im_height = int(data["size"]["height"])
            im_width = int(data["size"]["width"])

            wh = []
            for obj in data["object"]:
                xmin = float(obj["bndbox"]["xmin"])
                xmax = float(obj["bndbox"]["xmax"])
                ymin = float(obj["bndbox"]["ymin"])
                ymax = float(obj["bndbox"]["ymax"])
                wh.append([(xmax - xmin) / im_width, (ymax - ymin) / im_height])

            if len(wh) == 0:
                continue

            im_wh_list.append([im_width, im_height])
            boxes_wh_list.append(wh)

        return im_wh_list, boxes_wh_list


def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = np.minimum(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)


def k_means(boxes, k, dist=np.median):
    """
    yolo k-means method
    refer: https://github.com/qqwweee/keras-yolo3/blob/master/kmeans.py
    Args:
        boxes: bboxes to cluster
        k: number of clusters
        dist: method used to update the cluster centers (median by default, slightly better than the mean)
    """
    box_number = boxes.shape[0]
    last_nearest = np.zeros((box_number,))
    # randomly pick k bboxes as the initial cluster centers
    clusters = boxes[np.random.choice(box_number, k, replace=False)]

    while True:
        # distance between every bbox and every cluster: 1 - IOU(bboxes, anchors)
        distances = 1 - wh_iou(boxes, clusters)
        # index of the nearest cluster center for each bbox
        current_nearest = np.argmin(distances, axis=1)
        # if no bbox changed its cluster, the clustering has converged
        if (last_nearest == current_nearest).all():
            break  # clusters won't change
        for cluster in range(k):
            # recompute each cluster center from the bboxes assigned to it
            clusters[cluster] = dist(boxes[current_nearest == cluster], axis=0)

        last_nearest = current_nearest

    return clusters


def anchor_fitness(k: np.ndarray, wh: np.ndarray, thr: float):  # mutation fitness
    r = wh[:, None] / k[None]
    x = np.minimum(r, 1. / r).min(2)  # ratio metric
    # x = wh_iou(wh, k)  # iou metric
    best = x.max(1)
    f = (best * (best > thr).astype(np.float32)).mean()  # fitness
    bpr = (best > thr).astype(np.float32).mean()  # best possible recall
    return f, bpr


def main(img_size=512, n=9, thr=0.25, gen=1000):
    # read the wh of every image and of its bboxes from the dataset
    dataset = VOCDataSet(voc_root="./data", txt_name="train.txt")
    im_wh, boxes_wh = dataset.get_info()

    # scale the longer side of each image to img_size
    im_wh = np.array(im_wh, dtype=np.float32)
    shapes = img_size * im_wh / im_wh.max(1, keepdims=True)
    wh0 = np.concatenate([l * s for s, l in zip(shapes, boxes_wh)])  # wh

    # Filter: warn about extremely small objects
    i = (wh0 < 3.0).any(1).sum()
    if i:
        print(f'WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
    wh = wh0[(wh0 >= 2.0).any(1)]  # keep only boxes with at least one side >= 2 pixels

    # Kmeans calculation
    # print(f'Running kmeans for {n} anchors on {len(wh)} points...')
    # s = wh.std(0)  # sigmas for whitening
    # k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    # assert len(k) == n, print(f'ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}')
    # k *= s
    k = k_means(wh, n)

    # sort anchors by area, small to large
    k = k[np.argsort(k.prod(1))]
    f, bpr = anchor_fitness(k, wh, thr)
    print("kmeans: " + " ".join([f"[{int(i[0])}, {int(i[1])}]" for i in k]))
    print(f"fitness: {f:.5f}, best possible recall: {bpr:.5f}")

    # Evolve
    # genetic algorithm (mutate on top of the k-means result)
    npr = np.random
    f, sh, mp, s = anchor_fitness(k, wh, thr)[0], k.shape, 0.9, 0.1  # fitness, anchor shape, mutation prob, sigma
    pbar = tqdm(range(gen), desc=f'Evolving anchors with Genetic Algorithm:')  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg, bpr = anchor_fitness(kg, wh, thr)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = f'Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'

    # sort anchors by area, small to large
    k = k[np.argsort(k.prod(1))]
    print("genetic: " + " ".join([f"[{int(i[0])}, {int(i[1])}]" for i in k]))
    print(f"fitness: {f:.5f}, best possible recall: {bpr:.5f}")


# the parameters can be adjusted as needed
if __name__ == "__main__":
    main(640, 9, 0.25, 1000)
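As a quick, hand-checkable sketch of the 1-IOU distance used by the clustering (toy numbers, not from the post): both boxes are treated as if they share the same top-left corner, so only widths and heights matter.

import numpy as np

wh_boxes   = np.array([[2.0, 2.0]])   # a 2x2 box
wh_anchors = np.array([[2.0, 4.0]])   # a 2x4 anchor

iou = wh_iou(wh_boxes, wh_anchors)    # inter = min(2,2)*min(2,4) = 4; union = 4 + 8 - 4 = 8
print(iou)                            # [[0.5]] -> k-means distance = 1 - 0.5 = 0.5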
If you see the following error:
ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.
change this line:
xml = etree.fromstring(xml_str)
to:
xml = etree.fromstring(xml_str.encode('utf-8'))
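For context, a minimal sketch of why the error appears (toy XML string, not from an actual annotation file): lxml refuses to parse a Python str that still carries an XML encoding declaration, but accepts the same content as bytes.

from lxml import etree

xml_str = '<?xml version="1.0" encoding="utf-8"?><annotation><size/></annotation>'
# etree.fromstring(xml_str)                        # raises the ValueError shown above
xml = etree.fromstring(xml_str.encode('utf-8'))    # bytes input parses fine
print(xml.tag)                                     # -> annotation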