当前位置:   article > 正文

提取COCO 数据集的部分类

提取COCO 数据集的部分类

1.python提取COCO数据集中特定的类

安装pycocotools github地址:https://github.com/philferriere/cocoapi

  1. pip install git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI

若报错,pip install git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI

换成

pip install git+git://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI

实在不行的话,手动下载

  1. git clone https://github.com/pdollar/coco.git
  2. cd coco/PythonAPI
  3. python setup.py build_ext --inplace #安装到本地
  4. python setup.py build_ext install # 安装到Python环境中

缺少的库请自行用 pip 安装

注意skimage用pip install scikit-image -i https://pypi.tuna.tsinghua.edu.cn/simple

提取特定的类别如下:

  1. # coding: utf-8
  2. from pycocotools.coco import COCO
  3. import os
  4. import shutil
  5. from tqdm import tqdm
  6. import skimage.io as io
  7. import matplotlib.pyplot as plt
  8. import cv2
  9. from PIL import Image, ImageDraw
  # Destination root for the extracted subset (copied images + VOC-style XML).
  savepath = "/opt/10T/home/asc005/YangMingxiang/DenseCLIP_/data/COCO/"
  # Both directories keep their trailing slash: later code builds file paths
  # by plain string concatenation (e.g. img_dir + filename).
  img_dir = savepath + 'images/'
  anno_dir = savepath + 'Annotations/'
  # COCO splits to scan; use ['train2014', 'val2014'] for the 2014 release.
  datasets_list = ['train2017', 'val2017']
  # COCO category names to extract.
  classes_names = ['sheep']
  # Root of the original COCO data; the script reads
  # <dataDir>/annotations/instances_<split>.json and <dataDir>/<split>/<image>.
  dataDir = '/opt/10T/home/asc005/YangMingxiang/DenseCLIP_/data/coco/'
  # Opening part of a Pascal VOC annotation file.
  # Placeholders: (filename, width, height, depth).
  headstr = """\
  <annotation>
      <folder>VOC</folder>
      <filename>%s</filename>
      <source>
          <database>My Database</database>
          <annotation>COCO</annotation>
          <image>flickr</image>
          <flickrid>NULL</flickrid>
      </source>
      <owner>
          <flickrid>NULL</flickrid>
          <name>company</name>
      </owner>
      <size>
          <width>%d</width>
          <height>%d</height>
          <depth>%d</depth>
      </size>
      <segmented>0</segmented>
  """
  # One <object> entry. Placeholders: (name, xmin, ymin, xmax, ymax).
  objstr = """\
      <object>
          <name>%s</name>
          <pose>Unspecified</pose>
          <truncated>0</truncated>
          <difficult>0</difficult>
          <bndbox>
              <xmin>%d</xmin>
              <ymin>%d</ymin>
              <xmax>%d</xmax>
              <ymax>%d</ymax>
          </bndbox>
      </object>
  """
  # Closing tag that terminates the document started by headstr.
  tailstr = '''\
  </annotation>
  '''
  57. #if the dir is not exists,make it,else delete it
  def mkr(path):
      """Create ``path`` as an empty directory.

      An existing directory at ``path`` is deleted together with its
      contents first. Missing parent directories are created as well.

      Args:
          path: Directory to (re)create.
      """
      if os.path.exists(path):
          shutil.rmtree(path)
      # makedirs instead of mkdir so a missing parent does not abort the run
      os.makedirs(path)
  64. mkr(img_dir)
  65. mkr(anno_dir)
  def id2name(coco):
      """Return a ``{category_id: category_name}`` mapping.

      Args:
          coco: Object whose ``dataset['categories']`` is an iterable of
              dicts with ``'id'`` and ``'name'`` keys.
      """
      return {cat['id']: cat['name'] for cat in coco.dataset['categories']}
  def write_xml(anno_path, head, objs, tail):
      """Write one annotation file: ``head``, one entry per object, ``tail``.

      Args:
          anno_path: Path of the XML file to create (overwritten if present).
          head: Already formatted header text.
          objs: Iterable of ``[name, xmin, ymin, xmax, ymax]`` sequences,
              each rendered through the module-level ``objstr`` template.
          tail: Closing text.
      """
      # The context manager guarantees the file is flushed and closed even
      # if formatting one of the objects raises.
      with open(anno_path, "w") as f:
          f.write(head)
          for obj in objs:
              f.write(objstr % (obj[0], obj[1], obj[2], obj[3], obj[4]))
          f.write(tail)
  def save_annotations_and_imgs(coco, dataset, filename, objs):
      """Copy one image into ``img_dir`` and write its XML into ``anno_dir``.

      Args:
          coco: Not used by this function; kept so existing calls still work.
          dataset: Split folder name under ``dataDir`` (e.g. ``'train2017'``).
          filename: Image file name inside that split folder.
          objs: ``[name, xmin, ymin, xmax, ymax]`` entries for the image.
      """
      # e.g. 000000196610.jpg -> 000000196610.xml; splitext also copes with
      # extensions that are not exactly three characters long
      anno_path = anno_dir + os.path.splitext(filename)[0] + '.xml'
      img_path = dataDir + dataset + '/' + filename
      print(img_path)
      dst_imgpath = img_dir + filename

      img = cv2.imread(img_path)
      if img is None:
          # cv2.imread reports a missing or undecodable file by returning
          # None; skip it instead of failing later on img.shape
          print(filename + " could not be read, skipped")
          return

      shutil.copy(img_path, dst_imgpath)
      # img.shape is (height, width, channels)
      head = headstr % (filename, img.shape[1], img.shape[0], img.shape[2])
      write_xml(anno_path, head, objs, tailstr)
  def showimg(coco, dataset, img, classes, cls_id, show=True):
      """Collect the wanted boxes of one image and optionally display them.

      Args:
          coco: COCO API object providing ``getAnnIds`` and ``loadAnns``.
          dataset: Split folder name under ``dataDir``; only used to locate
              the image file when ``show`` is true.
          img: Image record with at least ``'id'`` and ``'file_name'``.
          classes: ``{category_id: category_name}`` mapping.
          cls_id: Category id(s) passed to ``getAnnIds`` as ``catIds``.
          show: When true, draw the boxes on the image and show it with
              matplotlib.

      Returns:
          A list of ``[class_name, xmin, ymin, xmax, ymax]`` lists, one per
          annotation whose class is in ``classes_names`` and that has a
          ``'bbox'``.
      """
      annIds = coco.getAnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
      anns = coco.loadAnns(annIds)

      objs = []
      for ann in anns:
          class_name = classes[ann['category_id']]
          if class_name not in classes_names:
              continue
          print(class_name)
          if 'bbox' not in ann:
              continue
          # bbox is [x, y, width, height]; convert to corner coordinates
          bbox = ann['bbox']
          xmin = int(bbox[0])
          ymin = int(bbox[1])
          xmax = int(bbox[2] + bbox[0])
          ymax = int(bbox[3] + bbox[1])
          objs.append([class_name, xmin, ymin, xmax, ymax])

      if show:
          # The image is only opened when it is actually displayed, so batch
          # runs with show=False do no image I/O here.
          I = Image.open('%s/%s/%s' % (dataDir, dataset, img['file_name']))
          draw = ImageDraw.Draw(I)
          for _, xmin, ymin, xmax, ymax in objs:
              draw.rectangle([xmin, ymin, xmax, ymax])
          plt.figure()
          plt.axis('off')
          plt.imshow(I)
          plt.show()
      return objs
  # Driver: for every split, export each image that contains a wanted class.
  for dataset in datasets_list:
      # e.g. <dataDir>/annotations/instances_train2017.json
      annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataset)
      # Load the annotation file through the COCO API
      coco = COCO(annFile)
      # Print every category known to this annotation file
      classes = id2name(coco)
      print(classes)
      # Category ids of all requested class names
      classes_ids = coco.getCatIds(catNms=classes_names)
      print(classes_ids)
      # showimg() is called with all requested category ids, so one pass per
      # image already yields every wanted box. Track exported images so an
      # image containing several requested classes is not converted again.
      exported = set()
      for cls in classes_names:
          cls_id = coco.getCatIds(catNms=[cls])
          img_ids = coco.getImgIds(catIds=cls_id)
          print(cls, len(img_ids))
          for imgId in tqdm(img_ids):
              if imgId in exported:
                  continue
              exported.add(imgId)
              img = coco.loadImgs(imgId)[0]
              filename = img['file_name']
              objs = showimg(coco, dataset, img, classes, classes_ids, show=False)
              print(objs)
              save_annotations_and_imgs(coco, dataset, filename, objs)

然后就可以了

2. 将上面获取的数据集划分为训练集和测试集
  1. # coding: utf-8
  2. import os
  3. import random
  4. from shutil import copy2
  # Source folders: the images and XML annotations to be split.
  image_original_path = "/opt/10T/home/asc005/YangMingxiang/DenseCLIP_/data/COCO/images"
  label_original_path = "/opt/10T/home/asc005/YangMingxiang/DenseCLIP_/data/COCO/Annotations"
  # Destination folders for the training part of the split.
  train_image_path = "/opt/10T/home/asc005/YangMingxiang/DenseCLIP_/data/COCO/train2017"
  train_label_path = "/opt/10T/home/asc005/YangMingxiang/DenseCLIP_/data/COCO/annotations/train2017"
  # Destination folders for the held-out part of the split.
  test_image_path = "/opt/10T/home/asc005/YangMingxiang/DenseCLIP_/data/COCO/val2017"
  test_label_path = "/opt/10T/home/asc005/YangMingxiang/DenseCLIP_/data/COCO/annotations/val2017"
  def mkdir():
      """Create the four output directories of the split if they are missing."""
      for path in (train_image_path, train_label_path,
                   test_image_path, test_label_path):
          # exist_ok avoids the check-then-create race of os.path.exists()
          os.makedirs(path, exist_ok=True)
  def main():
      """Randomly split the images and their XML labels into train and test.

      Each image goes to the training folders unless ``random.randint(1, 5)``
      returns 2, i.e. roughly 4 out of 5 images are used for training. The
      label copied with an image is the ``.xml`` file that has the same base
      name, so an image can never be paired with another image's annotation
      (pairing by position in two separate ``os.listdir`` results is not
      reliable: the listing order is arbitrary).
      """
      mkdir()
      for image_name in sorted(os.listdir(image_original_path)):
          label_name = os.path.splitext(image_name)[0] + '.xml'
          label_file = os.path.join(label_original_path, label_name)
          if not os.path.isfile(label_file):
              # Without an annotation the image is useless for training
              print('no annotation found for {}, skipped'.format(image_name))
              continue

          if random.randint(1, 5) != 2:
              image_dst, label_dst = train_image_path, train_label_path
          else:
              image_dst, label_dst = test_image_path, test_label_path
          copy2(os.path.join(image_original_path, image_name), image_dst)
          copy2(label_file, label_dst)
  if __name__ == '__main__':
      # Module-level lists available to main() for recording split indices.
      train_index = []
      val_index = []
      main()
3.将上一步提取的COCO 某一类 xml转为COCO标准的json文件:
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/8/27 10:48
  3. # @Author :Rock
  4. # @File : voc2coco.py
  5. # just for object detection
  6. import xml.etree.ElementTree as ET
  7. import os
  8. import json
  # Accumulator for the COCO-format output; the add*Item helpers append to it.
  coco = {
      'images': [],
      'type': 'instances',
      'annotations': [],
      'categories': [],
  }
  # Category name -> category id, for categories already registered.
  category_set = dict()
  # File names of images already registered.
  image_set = set()
  # Running counters used to hand out ids.
  category_item_id = 0
  image_id = 0
  annotation_id = 0
  def addCatItem(name):
      """Register a new category called ``name`` and return its id.

      Ids are handed out sequentially starting at 1. The category is appended
      to ``coco['categories']`` and recorded in ``category_set``.
      """
      global category_item_id
      category_item_id += 1
      coco['categories'].append({
          'supercategory': 'none',
          'id': category_item_id,
          'name': name,
      })
      category_set[name] = category_item_id
      return category_item_id
  def addImgItem(file_name, size):
      """Register an image and return the id stored in its image entry.

      Args:
          file_name: Image file name taken from the XML ``<filename>`` tag.
          size: Dict with ``'width'`` and ``'height'`` entries.

      Returns:
          The ``id`` written to the new ``coco['images']`` entry. Ids start
          at 0 and increase by one per image.

      Raises:
          Exception: If ``file_name``, ``size['width']`` or ``size['height']``
              is None.
      """
      global image_id
      if file_name is None:
          raise Exception('Could not find filename tag in xml file.')
      if size['width'] is None:
          raise Exception('Could not find width tag in xml file.')
      if size['height'] is None:
          raise Exception('Could not find height tag in xml file.')

      # Take the id before advancing the counter and return that same value.
      # Returning the already incremented counter made every annotation point
      # at image id + 1, i.e. at the wrong (or a non-existent) image.
      assigned_id = image_id
      image_id += 1
      coco['images'].append({
          'id': assigned_id,
          'file_name': file_name,
          'width': size['width'],
          'height': size['height'],
      })
      image_set.add(file_name)
      return assigned_id
  def addAnnoItem(object_name, image_id, category_id, bbox):
      """Append one annotation to ``coco['annotations']``.

      Args:
          object_name: Not used by this function; kept for existing callers.
          image_id: Id of the image the box belongs to.
          category_id: Id of the object's category.
          bbox: ``[x, y, width, height]`` of the box.

      The segmentation is the box outline as a single polygon, and the area
      is ``width * height``. Annotation ids start at 1.
      """
      global annotation_id
      x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
      # Corners in order: left-top, left-bottom, right-bottom, right-top
      outline = [x, y, x, y + h, x + w, y + h, x + w, y]
      annotation_id += 1
      coco['annotations'].append({
          'segmentation': [outline],
          'area': w * h,
          'iscrowd': 0,
          'ignore': 0,
          'image_id': image_id,
          'bbox': bbox,
          'category_id': category_id,
          'id': annotation_id,
      })
  def parseXmlFiles(xml_path):
      """Read every ``.xml`` file in ``xml_path`` into the ``coco`` accumulator.

      Files are processed in sorted name order so the generated image ids are
      reproducible between runs. For each file the image is registered with
      ``addImgItem`` once ``<filename>`` and ``<size>`` have been read, and
      every ``<object>`` with a ``<bndbox>`` is registered with
      ``addAnnoItem`` as an ``[x, y, width, height]`` box.

      Args:
          xml_path: Directory containing Pascal VOC style XML files.

      Raises:
          Exception: If the root tag is not ``annotation``, an image file name
              occurs twice, or the size / bndbox structure is inconsistent.
      """
      for f in sorted(os.listdir(xml_path)):
          if not f.endswith('.xml'):
              continue

          bndbox = dict()
          size = {'width': None, 'height': None, 'depth': None}
          current_image_id = None
          current_category_id = None
          file_name = None

          xml_file = os.path.join(xml_path, f)
          tree = ET.parse(xml_file)
          root = tree.getroot()
          if root.tag != 'annotation':
              raise Exception('pascal voc xml root element should be annotation, rather than {}'.format(root.tag))

          # elem is <folder>, <filename>, <size>, <object>, ...
          for elem in root:
              current_parent = elem.tag
              current_sub = None
              object_name = None

              if elem.tag == 'folder':
                  continue

              if elem.tag == 'filename':
                  file_name = elem.text
                  # Duplicates must be looked up among the registered image
                  # names; category_set only ever holds category names.
                  if file_name in image_set:
                      raise Exception('file_name duplicated')
              # Register the image on the first element seen after both
              # <filename> and <size> have been parsed.
              elif current_image_id is None and file_name is not None and size['width'] is not None:
                  if file_name not in image_set:
                      current_image_id = addImgItem(file_name, size)
                  else:
                      raise Exception('duplicated image: {}'.format(file_name))

              # subelem is <width>, <height>, <depth>, <name>, <bndbox>, ...
              for subelem in elem:
                  # Reset per sub-element: a box is emitted right after its
                  # own <bndbox> has been read (see the check below).
                  bndbox['xmin'] = None
                  bndbox['xmax'] = None
                  bndbox['ymin'] = None
                  bndbox['ymax'] = None

                  current_sub = subelem.tag
                  if current_parent == 'object' and subelem.tag == 'name':
                      object_name = subelem.text
                      if object_name not in category_set:
                          current_category_id = addCatItem(object_name)
                      else:
                          current_category_id = category_set[object_name]
                  elif current_parent == 'size':
                      if size[subelem.tag] is not None:
                          raise Exception('xml structure broken at size tag.')
                      size[subelem.tag] = int(subelem.text)

                  # option is <xmin>, <ymin>, <xmax>, <ymax> inside <bndbox>
                  for option in subelem:
                      if current_sub == 'bndbox':
                          if bndbox[option.tag] is not None:
                              raise Exception('xml structure corrupted at bndbox tag.')
                          bndbox[option.tag] = int(option.text)

                  # Only true directly after a <bndbox> was parsed
                  if bndbox['xmin'] is not None:
                      if object_name is None:
                          raise Exception('xml structure broken at bndbox tag')
                      if current_image_id is None:
                          raise Exception('xml structure broken at bndbox tag')
                      if current_category_id is None:
                          raise Exception('xml structure broken at bndbox tag')
                      # Corner coordinates -> [x, y, width, height]
                      bbox = [
                          bndbox['xmin'],
                          bndbox['ymin'],
                          bndbox['xmax'] - bndbox['xmin'],
                          bndbox['ymax'] - bndbox['ymin'],
                      ]
                      addAnnoItem(object_name, current_image_id, current_category_id, bbox)
  if __name__ == '__main__':
      # Adjust these two paths: xml_path is the directory holding the XML
      # files, json_file is the absolute path of the JSON file to generate.
      xml_path = r'G:\dataset\COCO\person\coco_val2014\annotations\\'
      json_file = r'G:\dataset\COCO\person\coco_val2014\instances_val2014.json'
      parseXmlFiles(xml_path)
      # Context manager so the JSON is flushed and the handle closed
      with open(json_file, 'w') as out:
          json.dump(coco, out)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/622710
推荐阅读
相关标签
  

闽ICP备14008679号