当前位置:   article > 正文

python提取COCO数据集中特定的类_coco数据集提取自己需要的类

coco数据集提取自己需要的类

记录一下提取Coco自行车类别的过程


1. 安装 pycocotools(GitHub 地址:https://github.com/philferriere/cocoapi)

 pip install git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI

2.提取其中的bicycle类的代码如下:

需要修改的地方

savepath

datasets_list

classes_names

dataDir

 代码参考自下面这篇博客:

https://blog.csdn.net/weixin_38632246/article/details/97141364

  1. from pycocotools.coco import COCO
  2. import os
  3. import shutil
  4. from tqdm import tqdm
  5. # import skimage.io as io
  6. import matplotlib.pyplot as plt
  7. import cv2
  8. from PIL import Image, ImageDraw
  9. #提取出的类别的保存路径
  10. savepath="/media/deepnorth/14b6945d-9936-41a8-aeac-505b96fc2be8/COCO/"
  11. img_dir=savepath+'images/'
  12. anno_dir=savepath+'Annotations/'
  13. # datasets_list=['train2014', 'val2014']
  14. datasets_list=['train2014']
  15. #这里填写需要提取的类别,本人此处提取bicycle
  16. classes_names = ['bicycle']
  17. #原coco数据集的目录
  18. dataDir= '/media/deepnorth/14b6945d-9936-41a8-aeac-505b96fc2be8/COCO/'
  19. headstr = """\
  20. <annotation>
  21. <folder>VOC</folder>
  22. <filename>%s</filename>
  23. <source>
  24. <database>My Database</database>
  25. <annotation>COCO</annotation>
  26. <image>flickr</image>
  27. <flickrid>NULL</flickrid>
  28. </source>
  29. <owner>
  30. <flickrid>NULL</flickrid>
  31. <name>company</name>
  32. </owner>
  33. <size>
  34. <width>%d</width>
  35. <height>%d</height>
  36. <depth>%d</depth>
  37. </size>
  38. <segmented>0</segmented>
  39. """
  40. objstr = """\
  41. <object>
  42. <name>%s</name>
  43. <pose>Unspecified</pose>
  44. <truncated>0</truncated>
  45. <difficult>0</difficult>
  46. <bndbox>
  47. <xmin>%d</xmin>
  48. <ymin>%d</ymin>
  49. <xmax>%d</xmax>
  50. <ymax>%d</ymax>
  51. </bndbox>
  52. </object>
  53. """
  54. tailstr = '''\
  55. </annotation>
  56. '''
  57. #if the dir is not exists,make it,else delete it
  58. def mkr(path):
  59. if os.path.exists(path):
  60. shutil.rmtree(path)
  61. os.mkdir(path)
  62. else:
  63. os.mkdir(path)
  64. mkr(img_dir)
  65. mkr(anno_dir)
  66. def id2name(coco):
  67. classes=dict()
  68. for cls in coco.dataset['categories']:
  69. classes[cls['id']]=cls['name']
  70. return classes
  71. def write_xml(anno_path,head, objs, tail):
  72. f = open(anno_path, "w")
  73. f.write(head)
  74. for obj in objs:
  75. f.write(objstr%(obj[0],obj[1],obj[2],obj[3],obj[4]))
  76. f.write(tail)
  77. def save_annotations_and_imgs(coco,dataset,filename,objs):
  78. #eg:COCO_train2014_000000196610.jpg-->COCO_train2014_000000196610.xml
  79. anno_path=anno_dir+filename[:-3]+'xml'
  80. img_path=dataDir+dataset+'/'+filename
  81. print(img_path)
  82. dst_imgpath=img_dir+filename
  83. img=cv2.imread(img_path)
  84. #if (img.shape[2] == 1):
  85. # print(filename + " not a RGB image")
  86. # return
  87. shutil.copy(img_path, dst_imgpath)
  88. head=headstr % (filename, img.shape[1], img.shape[0], img.shape[2])
  89. tail = tailstr
  90. write_xml(anno_path,head, objs, tail)
  91. def showimg(coco,dataset,img,classes,cls_id,show=True):
  92. global dataDir
  93. I=Image.open('%s/%s/%s'%(dataDir,dataset,img['file_name']))
  94. #通过id,得到注释的信息
  95. annIds = coco.getAnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
  96. # print(annIds)
  97. anns = coco.loadAnns(annIds)
  98. # print(anns)
  99. # coco.showAnns(anns)
  100. objs = []
  101. for ann in anns:
  102. class_name=classes[ann['category_id']]
  103. if class_name in classes_names:
  104. print(class_name)
  105. if 'bbox' in ann:
  106. bbox=ann['bbox']
  107. xmin = int(bbox[0])
  108. ymin = int(bbox[1])
  109. xmax = int(bbox[2] + bbox[0])
  110. ymax = int(bbox[3] + bbox[1])
  111. obj = [class_name, xmin, ymin, xmax, ymax]
  112. objs.append(obj)
  113. draw = ImageDraw.Draw(I)
  114. draw.rectangle([xmin, ymin, xmax, ymax])
  115. if show:
  116. plt.figure()
  117. plt.axis('off')
  118. plt.imshow(I)
  119. plt.show()
  120. return objs
  121. for dataset in datasets_list:
  122. #./COCO/annotations/instances_train2014.json
  123. annFile='{}/annotations/instances_{}.json'.format(dataDir,dataset)
  124. #COCO API for initializing annotated data
  125. coco = COCO(annFile)
  126. #show all classes in coco
  127. classes = id2name(coco)
  128. print(classes)
  129. #[1, 2, 3, 4, 6, 8]
  130. classes_ids = coco.getCatIds(catNms=classes_names)
  131. print(classes_ids)
  132. for cls in classes_names:
  133. #Get ID number of this class
  134. cls_id=coco.getCatIds(catNms=[cls])
  135. img_ids=coco.getImgIds(catIds=cls_id)
  136. print(cls,len(img_ids))
  137. # imgIds=img_ids[0:10]
  138. for imgId in tqdm(img_ids):
  139. img = coco.loadImgs(imgId)[0]
  140. filename = img['file_name']
  141. # print(filename)
  142. objs=showimg(coco, dataset, img, classes,classes_ids,show=False)
  143. print(objs)
  144. save_annotations_and_imgs(coco, dataset, filename, objs)

 

COCO数据集2014

代码执行完之后会生成对应的  images文件夹和 Annotations(.xml)文件夹

 

有了这两个文件夹,就可以利用 VOC 转 YOLO 的代码,把 XML 标注转换为 YOLO 目标检测所需的 txt 标签文件

相关代码

需要修改的参数

classes

data_path

list_file

in_file

out_file

  1. import xml.etree.ElementTree as ET
  2. import pickle
  3. import os
  4. from os import listdir, getcwd
  5. from os.path import join
  6. classes = ["bicycle"]
  7. def convert(size, box):
  8. dw = 1./(size[0])
  9. dh = 1./(size[1])
  10. x = (box[0] + box[1])/2.0 - 1
  11. y = (box[2] + box[3])/2.0 - 1
  12. w = box[1] - box[0]
  13. h = box[3] - box[2]
  14. x = x*dw
  15. w = w*dw
  16. y = y*dh
  17. h = h*dh
  18. return (x,y,w,h)
  19. def convert_annotation(image_id):
  20. in_file = open('coco_voc_val/Annotations/%s.xml'%(image_id))
  21. out_file = open('coco_voc_val/labels/%s.txt'%(image_id), 'w')
  22. tree=ET.parse(in_file)
  23. root = tree.getroot()
  24. size = root.find('size')
  25. w = int(size.find('width').text)
  26. h = int(size.find('height').text)
  27. for obj in root.iter('object'):
  28. difficult = obj.find('difficult').text
  29. cls = obj.find('name').text
  30. print(cls)
  31. if cls not in classes or int(difficult)==1:
  32. continue
  33. cls_id = classes.index(cls)
  34. xmlbox = obj.find('bndbox')
  35. b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
  36. bb = convert((w,h), b)
  37. out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
  38. data_path = '/media/COCO/coco_voc_val/images'
  39. img_names = os.listdir(data_path)
  40. list_file = open('2014_val.txt', 'w')
  41. for img_name in img_names:
  42. if not os.path.exists('coco_voc_val/labels'):
  43. os.makedirs('coco_voc_val/labels')
  44. list_file.write('/media/COCO/coco_voc_val/images/%s\n'%img_name)
  45. image_id = img_name[:-4]
  46. convert_annotation(image_id)
  47. list_file.close()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/308691
推荐阅读
相关标签
  

闽ICP备14008679号