赞
踩
from PIL import Image
import numpy as np

# Dump a mask to disk as an RGB image for quick visual inspection.
# NOTE(review): `gt_masks` is not defined in this snippet. It exposes `.cpu()`,
# so it is presumably a torch tensor. Scaling by 255 and casting to uint8
# assumes its values lie in [0, 1]: larger values would wrap around. A
# single-channel (H, W) array is also assumed so that `Image.fromarray`
# yields a grayscale image -- TODO confirm against the caller.
mask = np.array(gt_masks.cpu() * 255).astype('uint8')
img = Image.fromarray(mask).convert('RGB')
img.save('picture.jpg')
数据集官网：https://captain-whu.github.io/DOAI2019/dataset.html
数据集工具包：https://github.com/CAPTAIN-WHU/DOTA_devkit
DOTA-v1.5数据集共有16个类别，包含约40万个带标注的目标实例。
训练集：1411张
验证集：458张
16个类别分别是：飞机、轮船、储罐、棒球场、网球场、篮球场、田径场、港口、桥梁、小型车辆、大型车辆、直升机、环形交叉路口、足球场、游泳池和集装箱起重机。
plane, ship, storage tank, baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, small vehicle, large vehicle, helicopter, roundabout, soccer ball field, swimming pool and container crane.
原始图像尺寸非常大，训练前需要将其切割成较小的patch，数据集工具包中提供了相应的处理代码（ImgSplit.py）。
其中，官方工具包提供的DOTA.py脚本可以对包含指定类别的图片及其标注进行可视化。
#The code is used for visulization, inspired from cocoapi # Licensed under the Simplified BSD License [see bsd.txt] import os import matplotlib.pyplot as plt from matplotlib.collections import PatchCollection from matplotlib.patches import Polygon, Circle import numpy as np import dota_utils as util from collections import defaultdict import cv2 def _isArrayLike(obj): if type(obj) == str: return False return hasattr(obj, '__iter__') and hasattr(obj, '__len__') class DOTA: def __init__(self, basepath): self.basepath = basepath self.labelpath = os.path.join(basepath, 'labelTxt') self.imagepath = os.path.join(basepath, 'images') self.imgpaths = util.GetFileFromThisRootDir(self.labelpath) # 每个图片txt文件的绝对路径 self.imglist = [util.custombasename(x) for x in self.imgpaths] # 每个图片的前缀名字 比如P1506 self.catToImgs = defaultdict(list) self.ImgToAnns = defaultdict(list) # 存放每个类别下 图片名字 self.createIndex() def createIndex(self): for filename in self.imgpaths: # 对于每个文件txt处理他的标注角点 存储为字典 name为类别 poly为坐标 area为区域类似形状 objects = util.parse_dota_poly(filename) imgid = util.custombasename(filename) self.ImgToAnns[imgid] = objects for obj in objects: cat = obj['name'] self.catToImgs[cat].append(imgid) def getImgIds(self, catNms=[]): """ :param catNms: category names :return: all the image ids contain the categories """ catNms = catNms if _isArrayLike(catNms) else [catNms] if len(catNms) == 0: return self.imglist else: imgids = [] for i, cat in enumerate(catNms): if i == 0: imgids = set(self.catToImgs[cat]) else: imgids &= set(self.catToImgs[cat]) return list(imgids) def loadAnns(self, catNms=[], imgId = None, difficult=None): """ :param catNms: category names :param imgId: the img to load anns :return: objects """ catNms = catNms if _isArrayLike(catNms) else [catNms] objects = self.ImgToAnns[imgId] if len(catNms) == 0: return objects outobjects = [obj for obj in objects if (obj['name'] in catNms)] return outobjects def showAnns(self, objects, imgId, range): """ :param catNms: category names 
:param objects: objects to show :param imgId: img to show :param range: display range in the img :return: """ img = self.loadImgs(imgId)[0] plt.imshow(img) plt.axis('off') ax = plt.gca() ax.set_autoscale_on(False) polygons = [] color = [] circles = [] r = 5 for obj in objects: c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0] poly = obj['poly'] polygons.append(Polygon(poly)) color.append(c) point = poly[0] circle = Circle((point[0], point[1]), r) circles.append(circle) p = PatchCollection(polygons, facecolors=color, linewidths=0, alpha=0.4) ax.add_collection(p) p = PatchCollection(polygons, facecolors='none', edgecolors=color, linewidths=2) ax.add_collection(p) p = PatchCollection(circles, facecolors=
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。