当前位置:   article > 正文

CUB200-2011鸟类细粒度数据集训练集和测试集划分python代码

cub200-2011

 CUB200-2011数据集介绍:

        该数据集由加州理工学院再2010年提出的细粒度数据集,也是目前细粒度分类识别研究的基准图像数据集。

        该数据集共有11788张鸟类图像,包含200类鸟类子类,其中训练数据集有5994张图像,测试集有5794张图像,每张图像均提供了图像类标记信息,图像中鸟的bounding box,鸟的关键part信息,以及鸟类的属性信息,数据集如下图所示。


下载的数据集中,包含了如下文件:

bounding_boxes.txt;classes.txt;image_class_labels.txt; images.txt; train_test_split.txt.

其中,bounding_boxes.txt为图像中鸟类的边界框信息;classes.txt为鸟类的类别信息,共有200类; image_class_labels.txt为图像标签和所属类别标签信息;images.txt为图像的标签和图像路径信息;train_test_split.txt为训练集和测试集划分。

本博客主要是根据train_test_split.txt文件和images.txt文件将原始下载的CUB200-2011数据集划分为训练集和测试集。在深度学习Pytorch框架下采用ImageFolder和DataLoader读取数据集较为方便。相关的python代码如下:

(1) CUB200-2011训练集和测试集划分代码

  1. # *_*coding: utf-8 *_*
  2. # author --liming--
  3. """
  4. 读取images.txt文件,获得每个图像的标签
  5. 读取train_test_split.txt文件,获取每个图像的train, test标签.其中1为训练,0为测试.
  6. """
  7. import os
  8. import shutil
  9. import numpy as np
  10. import config
  11. import time
  12. time_start = time.time()
  13. # 文件路径
  14. path_images = config.path + 'images.txt'
  15. path_split = config.path + 'train_test_split.txt'
  16. trian_save_path = config.path + 'dataset/train/'
  17. test_save_path = config.path + 'dataset/test/'
  18. # 读取images.txt文件
  19. images = []
  20. with open(path_images,'r') as f:
  21. for line in f:
  22. images.append(list(line.strip('\n').split(',')))
  23. # 读取train_test_split.txt文件
  24. split = []
  25. with open(path_split, 'r') as f_:
  26. for line in f_:
  27. split.append(list(line.strip('\n').split(',')))
  28. # 划分
  29. num = len(images) # 图像的总个数
  30. for k in range(num):
  31. file_name = images[k][0].split(' ')[1].split('/')[0]
  32. aaa = int(split[k][0][-1])
  33. if int(split[k][0][-1]) == 1: # 划分到训练集
  34. #判断文件夹是否存在
  35. if os.path.isdir(trian_save_path + file_name):
  36. shutil.copy(config.path + 'images/' + images[k][0].split(' ')[1], trian_save_path+file_name+'/'+images[k][0].split(' ')[1].split('/')[1])
  37. else:
  38. os.makedirs(trian_save_path + file_name)
  39. shutil.copy(config.path + 'images/' + images[k][0].split(' ')[1], trian_save_path + file_name + '/' + images[k][0].split(' ')[1].split('/')[1])
  40. print('%s处理完毕!' % images[k][0].split(' ')[1].split('/')[1])
  41. else:
  42. #判断文件夹是否存在
  43. if os.path.isdir(test_save_path + file_name):
  44. aaaa = config.path + 'images/' + images[k][0].split(' ')[1]
  45. bbbb = test_save_path+file_name+'/'+images[k][0].split(' ')[1]
  46. shutil.copy(config.path + 'images/' + images[k][0].split(' ')[1], test_save_path+file_name+'/'+images[k][0].split(' ')[1].split('/')[1])
  47. else:
  48. os.makedirs(test_save_path + file_name)
  49. shutil.copy(config.path + 'images/' + images[k][0].split(' ')[1], test_save_path + file_name + '/' + images[k][0].split(' ')[1].split('/')[1])
  50. print('%s处理完毕!' % images[k][0].split(' ')[1].split('/')[1])
  51. time_end = time.time()
  52. print('CUB200训练集和测试集划分完毕, 耗时%s!!' % (time_end - time_start))

config文件

  1. # *_*coding: utf-8 *_*
  2. # author --liming--
  3. path = '/media/lm/C3F680DFF08EB695/细粒度数据集/birds/CUB200/CUB_200_2011/'
  4. ROOT_TRAIN = path + 'images/train/'
  5. ROOT_TEST = path + 'images/test/'
  6. BATCH_SIZE = 16

(2) 利用Pytorch方式读取数据

  1. # *_*coding: utf-8 *_*
  2. # author --liming--
  3. """
  4. 用于已下载数据集的转换,便于pytorch的读取
  5. """
  6. import torch
  7. import torchvision
  8. import config
  9. from torchvision import datasets, transforms
  10. data_transform = transforms.Compose([
  11. transforms.Resize(224),
  12. transforms.ToTensor(),
  13. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  14. ])
  15. def train_data_load():
  16. # 训练集
  17. root_train = config.ROOT_TRAIN
  18. train_dataset = torchvision.datasets.ImageFolder(root_train,
  19. transform=data_transform)
  20. CLASS = train_dataset.class_to_idx
  21. print('训练数据label与文件名的关系:', CLASS)
  22. train_loader = torch.utils.data.DataLoader(train_dataset,
  23. batch_size=config.BATCH_SIZE,
  24. shuffle=True)
  25. return CLASS, train_loader
  26. def test_data_load():
  27. # 测试集
  28. root_test = config.ROOT_TEST
  29. test_dataset = torchvision.datasets.ImageFolder(root_test,
  30. transform=data_transform)
  31. CLASS = test_dataset.class_to_idx
  32. print('测试数据label与文件名的关系:',CLASS)
  33. test_loader = torch.utils.data.DataLoader(test_dataset,
  34. batch_size=config.BATCH_SIZE,
  35. shuffle=True)
  36. return CLASS, test_loader
  37. if __name__ == '__main___':
  38. train_data_load()
  39. test_data_load()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/248989
推荐阅读
相关标签
  

闽ICP备14008679号