
Parsing and Visualizing the CIFAR-10 and CIFAR-100 Datasets

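CIFAR-10 and CIFAR-100 are two widely used image classification datasets. Because they are so common, many libraries provide loaders for them, and a training or test set can usually be built with a single call. But what do the two datasets actually look like, and how can we extract them with our own code? In this post we write our own functions to pull the images, labels, filenames, and other information out of the dataset files, and then visualize them.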

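1. Obtaining the datasets

First, download the data and store it locally.

Original download links: http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz and http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz.

If the download fails because of network problems, you can also get the files from the resource I uploaded: https://download.csdn.net/download/oYeZhou/12711792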
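If you would rather script this step, here is a minimal sketch using only the standard library: `urllib.request.urlretrieve` downloads an archive and `tarfile` unpacks it. The helper names `fetch_and_extract` and `extract_archive` and the `./data` default are hypothetical, and the sketch assumes the URLs above are reachable from your machine.

```python
import os
import tarfile
import urllib.request


def extract_archive(archive_path, dest_dir):
    """Unpack a .tar.gz archive into dest_dir and return its top-level entry names."""
    with tarfile.open(archive_path, 'r:gz') as tar:
        if hasattr(tarfile, 'data_filter'):
            # Newer Python versions: refuse absolute paths and '..' entries
            tar.extractall(dest_dir, filter='data')
        else:
            tar.extractall(dest_dir)
        return sorted({m.name.split('/')[0] for m in tar.getmembers()})


def fetch_and_extract(url, dest_dir='./data'):
    """Hypothetical helper: download url into dest_dir (unless already there) and unpack it."""
    os.makedirs(dest_dir, exist_ok=True)
    archive_path = os.path.join(dest_dir, url.rsplit('/', 1)[-1])
    if not os.path.exists(archive_path):
        urllib.request.urlretrieve(url, archive_path)
    return extract_archive(archive_path, dest_dir)
```

For example, `fetch_and_extract('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')` should leave a `cifar-10-batches-py` folder under `./data`; an archive that is already present is not downloaded again.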

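2. Directory structure of the dataset files

Both CIFAR-10 and CIFAR-100 are stored the same way: each image is flattened into a 3072-dimensional vector (3 × 32 × 32), the vectors are stacked into one array, the array is put into a dictionary together with the corresponding labels and filenames, and the dictionary is serialized (pickled) to a file.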
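To make this storage format concrete, the sketch below builds a tiny fake batch in the same layout: two random 3×32×32 images are flattened to 3072 values each, stacked into an N×3072 uint8 array, stored in a dict with bytes keys next to labels and filenames, and round-tripped through pickle. The helper `make_fake_batch`, the pixel values, and the filenames are made up for illustration only.

```python
import pickle

import numpy as np


def make_fake_batch(num_images=2, seed=0):
    """Hypothetical helper: build a dict laid out like a CIFAR batch file."""
    rng = np.random.default_rng(seed)
    # Random images in (channel, height, width) layout
    imgs = rng.integers(0, 256, size=(num_images, 3, 32, 32), dtype=np.uint8)
    # Flatten each image to 3072 values; one row per image
    data = imgs.reshape(num_images, 3 * 32 * 32)
    batch = {
        b'batch_label': b'fake batch 1 of 1',
        b'labels': [i % 10 for i in range(num_images)],
        b'data': data,
        b'filenames': [b'img_%d.png' % i for i in range(num_images)],
    }
    return imgs, batch


imgs, batch = make_fake_batch()
blob = pickle.dumps(batch)                       # serialize, like the dataset files
restored = pickle.loads(blob, encoding='bytes')
print(restored[b'data'].shape)                   # (2, 3072)
# Reshaping one row recovers the original image exactly
print(np.array_equal(restored[b'data'][0].reshape(3, 32, 32), imgs[0]))  # True
```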

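The serialized files are organized as follows:

Figure: CIFAR-10 directory structure
Figure: CIFAR-100 directory structure

As the figures show, CIFAR-10 splits its training set into 5 batch files (data_batch_1 to data_batch_5) and stores the test set in a single file (test_batch), while CIFAR-100 stores the training set and the test set in one file each (train and test). batches.meta and meta hold the class information of CIFAR-10 and CIFAR-100, respectively.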
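Given that layout, a quick check like the sketch below can confirm that an extracted folder contains every batch and meta file before you start loading. The `EXPECTED_FILES` table and the `missing_files` function are hypothetical helpers; the file names are the ones the loading code later in this post reads.

```python
import os

# File names used by the Python versions of the two datasets
EXPECTED_FILES = {
    'cifar10': ['data_batch_%d' % i for i in range(1, 6)] + ['test_batch', 'batches.meta'],
    'cifar100': ['train', 'test', 'meta'],
}


def missing_files(root, dataset='cifar10'):
    """Return the expected batch/meta files that are absent under root."""
    return [name for name in EXPECTED_FILES[dataset]
            if not os.path.isfile(os.path.join(root, name))]
```

For a complete folder, `missing_files('../cifar10/cifar-10-batches-py')` returns an empty list.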

3. Data structure of the stored files

We can deserialize a file with pickle to inspect its data structure:

```python
import pickle

filename = '../cifar10/cifar-10-batches-py/test_batch'
with open(filename, 'rb') as f:
    dataset = pickle.load(f, encoding='bytes')
print(type(dataset))
# out: <class 'dict'>
```

Running this code prints the type of the loaded object. It is a dict, so we can print its keys:

```python
print(dataset.keys())
# out: dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
```

Each item can then be extracted by its key:

```python
data = dataset[b'data']
labels = dataset[b'labels']
img_names = dataset[b'filenames']
```
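Because the file is loaded with `encoding='bytes'`, the keys, the filenames, and the batch label all come back as `bytes` objects. If `str` is more convenient, one option is to decode them once right after loading, as in the sketch below. The `decode_batch` name is hypothetical; the data array and the label list are passed through unchanged.

```python
def decode_batch(dataset):
    """Return a copy of a batch dict with str keys, str filenames and str batch label."""
    out = {}
    for key, value in dataset.items():
        name = key.decode('utf-8') if isinstance(key, bytes) else key
        if name == 'filenames':
            # List of bytes -> list of str
            value = [v.decode('utf-8') if isinstance(v, bytes) else v for v in value]
        elif isinstance(value, bytes):
            # Single bytes value such as the batch label
            value = value.decode('utf-8')
        out[name] = value
    return out
```

After `dataset = decode_batch(dataset)`, the fields are read as `dataset['data']`, `dataset['labels']`, and so on.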

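The other files can be inspected the same way. Here is what I found:

  • CIFAR-10 has 10 classes; CIFAR-100 has 20 coarse classes (superclasses) and 100 fine classes;
  • In both datasets the training set has 50,000 images and the test set 10,000;
  • Each image is stored as a flat array of 3072 values that must be reshaped to 3×32×32; the channel order is RGB.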
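Concretely, within each row the first 1024 values are the red channel, the next 1024 the green channel, and the last 1024 the blue channel, each stored row by row as a 32×32 image. The pixel at channel c, row h, column w therefore sits at index c·1024 + h·32 + w, which is exactly what `reshape(3, 32, 32)` relies on. The sketch below checks this on a synthetic row whose values equal their positions (real rows are uint8 pixel values).

```python
import numpy as np

# Synthetic row: each value equals its position, so we can see where it lands
row = np.arange(3072)
img = row.reshape(3, 32, 32)      # (channel, height, width)

c, h, w = 2, 5, 7                 # blue channel, row 5, column 7
flat_index = c * 1024 + h * 32 + w
print(flat_index)                 # 2215
print(img[c, h, w] == row[flat_index])  # True
```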

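4. Extracting the data

4.1 Extracting the class information

```python
import pickle


def load_labels_name(filename):
    """Deserialize a label file with pickle and return its contents.
    The CIFAR-10 label file is "batches.meta"; the CIFAR-100 one is "meta".
    The result is a dict whose fields can be read by key.
    """
    with open(filename, 'rb') as f:
        obj = pickle.load(f)
    return obj
```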

The CIFAR-10 label file is batches.meta. Using it as an example, the function above returns:

```python
{'label_names': ['airplane',
                 'automobile',
                 'bird',
                 'cat',
                 'deer',
                 'dog',
                 'frog',
                 'horse',
                 'ship',
                 'truck'],
 'num_cases_per_batch': 10000,
 'num_vis': 3072}
```

4.2 Extracting images, labels, and filenames

```python
import pickle


def load_data_cifar(filename, mode='cifar10'):
    """Load the data and label information from one CIFAR-10 or CIFAR-100 batch file.
    cifar10 keys(): dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
    cifar100 keys(): dict_keys([b'filenames', b'batch_label', b'fine_labels', b'coarse_labels', b'data'])
    """
    with open(filename, 'rb') as f:
        dataset = pickle.load(f, encoding='bytes')
    if mode == 'cifar10':
        data = dataset[b'data']
        labels = dataset[b'labels']
        img_names = dataset[b'filenames']
    elif mode == 'cifar100':
        data = dataset[b'data']
        labels = dataset[b'fine_labels']   # use b'coarse_labels' for the 20 superclasses
        img_names = dataset[b'filenames']
    else:
        print("mode should be in ['cifar10', 'cifar100']")
        return None, None, None
    return data, labels, img_names
```

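This function extracts the images, labels, and filenames from one batch file and returns them. The returned data is an N×3072 array in which each row is one image. labels holds the class labels; for CIFAR-100 it is a list of integers in the range 0 to 99, i.e. the 100 fine classes. To extract the 20 coarse classes instead, change the line `labels = dataset[b'fine_labels']` to `labels = dataset[b'coarse_labels']`.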
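Since every CIFAR-100 batch carries both `b'fine_labels'` and `b'coarse_labels'`, you can also recover which superclass each fine class belongs to by pairing the two lists. A minimal sketch, assuming `dataset` is a CIFAR-100 dict loaded with `encoding='bytes'` as above; the `fine_to_coarse` name is hypothetical:

```python
def fine_to_coarse(fine_labels, coarse_labels):
    """Map each fine label (0-99) to its coarse label (0-19).

    Raises ValueError if a fine label appears with two different coarse labels.
    """
    mapping = {}
    for fine, coarse in zip(fine_labels, coarse_labels):
        # Store the first coarse label seen; later ones must agree with it
        if mapping.setdefault(fine, coarse) != coarse:
            raise ValueError('fine label %d maps to several coarse labels' % fine)
    return mapping
```

Usage: `mapping = fine_to_coarse(dataset[b'fine_labels'], dataset[b'coarse_labels'])`; each of the 100 fine classes belongs to exactly one of the 20 superclasses, so the check should never fire on the real files.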

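5. Converting the image arrays to images

With the code from Section 4.2 we can extract the image data. Reshaping it turns each row into a [channel, height, width] array:

```python
imgs_cifar10_train = data_cifar10_train.reshape(data_cifar10_train.shape[0], 3, 32, 32)
```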

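Each of these arrays can then be converted to an image:

```python
from PIL import Image


def to_pil(data):
    """Convert one (3, 32, 32) uint8 array to an RGB PIL image."""
    r = Image.fromarray(data[0])   # each channel becomes a grayscale ('L') image
    g = Image.fromarray(data[1])
    b = Image.fromarray(data[2])
    pil_img = Image.merge('RGB', (r, g, b))
    return pil_img
```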
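`to_pil` builds the image channel by channel. An equivalent, shorter route is to move the channel axis to the end with NumPy's `transpose`, giving the (height, width, channel) layout that `Image.fromarray` turns into an RGB image and that matplotlib's `imshow` accepts directly. A sketch; the names `to_hwc` and `to_pil_hwc` are hypothetical:

```python
import numpy as np
from PIL import Image


def to_hwc(imgs):
    """Convert an (N, 3, 32, 32) batch to (N, 32, 32, 3) for display."""
    return np.ascontiguousarray(imgs.transpose(0, 2, 3, 1))


def to_pil_hwc(img_chw):
    """Convert one (3, 32, 32) uint8 array to an RGB PIL image via transpose."""
    return Image.fromarray(np.ascontiguousarray(img_chw.transpose(1, 2, 0)))
```

With the whole batch converted once, `plt.imshow(to_hwc(imgs_cifar10_train)[0])` shows the first training image without going through PIL.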

6. Visualizing samples from each class

First, the visualization code:

```python
import numpy as np
import matplotlib.pyplot as plt


def random_visualize(imgs, labels, label_names):
    """Show 10 random samples per class; each column is one class.
    Uses to_pil() from Section 5.
    """
    num_classes = len(label_names)
    plt.figure(figsize=(num_classes, 10))
    idxs = list(range(len(imgs)))
    np.random.shuffle(idxs)
    count = [0] * num_classes
    for idx in idxs:
        # Stop as soon as every class has 10 samples
        if sum(count) >= 10 * num_classes:
            break
        label = labels[idx]
        if count[label] >= 10:
            continue
        img = to_pil(imgs[idx])
        label_name = label_names[label]
        # Row = sample number within the class, column = class index
        subplot_idx = count[label] * num_classes + label + 1
        plt.subplot(10, num_classes, subplot_idx)
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
        if count[label] == 0:
            plt.title(label_name)
        count[label] += 1
    plt.show()
```

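The function above randomly picks 10 samples per class and arranges them in columns, one column per class, with the class name as the title of the top row. Pass in the images, labels, and class names extracted from either dataset to draw the visualization.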
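The loop above shuffles all indices and walks through them until each class has 10 samples. Another way to pick the samples is to draw the indices for each class directly with NumPy: `np.flatnonzero` finds the positions of one class and `Generator.choice` picks k of them without replacement. The sketch below is an alternative under that approach, not the method used in this post; the `sample_per_class` name and its `seed` parameter are hypothetical. It returns a (k, num_classes) index grid whose column j holds samples of class j, matching the column layout of `random_visualize`.

```python
import numpy as np


def sample_per_class(labels, num_classes, k=10, seed=None):
    """Return a (k, num_classes) index array; column j holds k samples of class j.

    Raises ValueError if some class has fewer than k samples.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    grid = np.empty((k, num_classes), dtype=np.int64)
    for cls in range(num_classes):
        candidates = np.flatnonzero(labels == cls)   # positions of this class
        grid[:, cls] = rng.choice(candidates, size=k, replace=False)
    return grid
```

Each `grid[row, cls]` can then be drawn at subplot `row * num_classes + cls + 1`, which is the same position `random_visualize` computes.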

7. Full code and visualization results

```python
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 12 16:23:45 2020
@author: LWS
Extract images, labels, filenames and other information from the
serialized CIFAR-10 and CIFAR-100 files.
"""
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


def load_labels_name(filename):
    """Deserialize a label file with pickle and return its contents.
    The CIFAR-10 label file is "batches.meta"; the CIFAR-100 one is "meta".
    The result is a dict whose fields can be read by key.
    """
    with open(filename, 'rb') as f:
        obj = pickle.load(f)
    return obj


def load_data_cifar(filename, mode='cifar10'):
    """Load the data and label information from one CIFAR-10 or CIFAR-100 batch file.
    cifar10 keys(): dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
    cifar100 keys(): dict_keys([b'filenames', b'batch_label', b'fine_labels', b'coarse_labels', b'data'])
    """
    with open(filename, 'rb') as f:
        dataset = pickle.load(f, encoding='bytes')
    if mode == 'cifar10':
        data = dataset[b'data']
        labels = dataset[b'labels']
        img_names = dataset[b'filenames']
    elif mode == 'cifar100':
        data = dataset[b'data']
        labels = dataset[b'fine_labels']
        img_names = dataset[b'filenames']
    else:
        print("mode should be in ['cifar10', 'cifar100']")
        return None, None, None
    return data, labels, img_names


def load_cifar10(cifar10_path, mode='train'):
    if mode == "train":
        # Concatenate the 5 training batches
        data_all = np.empty(shape=[0, 3072], dtype=np.uint8)
        labels_all = []
        img_names_all = []
        for i in range(1, 6):
            filename = os.path.join(cifar10_path, 'data_batch_' + str(i)).replace('\\', '/')
            print("Loading {}".format(filename))
            data, labels, img_names = load_data_cifar(filename, mode='cifar10')
            data_all = np.vstack((data_all, data))
            labels_all += labels
            img_names_all += img_names
        return data_all, labels_all, img_names_all
    elif mode == "test":
        filename = os.path.join(cifar10_path, 'test_batch').replace('\\', '/')
        print("Loading {}".format(filename))
        return load_data_cifar(filename, mode='cifar10')
    else:
        print("mode should be in ['train', 'test']")
        return None, None, None


def load_cifar100(cifar100_path, mode='train'):
    if mode == "train":
        filename = os.path.join(cifar100_path, 'train')
    elif mode == "test":
        filename = os.path.join(cifar100_path, 'test')
    else:
        print("mode should be in ['train', 'test']")
        return None, None, None
    print("Loading {}".format(filename))
    return load_data_cifar(filename, mode='cifar100')


def to_pil(data):
    """Convert one (3, 32, 32) uint8 array to an RGB PIL image."""
    r = Image.fromarray(data[0])
    g = Image.fromarray(data[1])
    b = Image.fromarray(data[2])
    pil_img = Image.merge('RGB', (r, g, b))
    return pil_img


def random_visualize(imgs, labels, label_names):
    """Show 10 random samples per class; each column is one class."""
    num_classes = len(label_names)
    plt.figure(figsize=(num_classes, 10))
    idxs = list(range(len(imgs)))
    np.random.shuffle(idxs)
    count = [0] * num_classes
    for idx in idxs:
        # Stop as soon as every class has 10 samples
        if sum(count) >= 10 * num_classes:
            break
        label = labels[idx]
        if count[label] >= 10:
            continue
        img = to_pil(imgs[idx])
        label_name = label_names[label]
        # Row = sample number within the class, column = class index
        subplot_idx = count[label] * num_classes + label + 1
        plt.subplot(10, num_classes, subplot_idx)
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
        if count[label] == 0:
            plt.title(label_name)
        count[label] += 1
    plt.show()


if __name__ == "__main__":
    # Change these to where your datasets are stored
    cifar10_path = "../cifar10/cifar-10-batches-py"
    cifar100_path = "../cifar100/cifar-100-python"
    obj_cifar10 = load_labels_name(os.path.join(cifar10_path, 'batches.meta'))  # label_names, num_cases_per_batch, num_vis
    obj_cifar100 = load_labels_name(os.path.join(cifar100_path, 'meta'))  # coarse_label_names, fine_label_names

    # Extract the image data, labels and filenames of CIFAR-10 and CIFAR-100
    data_cifar10_train, labels_cifar10_train, img_names_cifar10_train = \
        load_cifar10(cifar10_path, mode='train')
    data_cifar10_test, labels_cifar10_test, img_names_cifar10_test = \
        load_cifar10(cifar10_path, mode='test')
    imgs_cifar10_train = data_cifar10_train.reshape(data_cifar10_train.shape[0], 3, 32, 32)
    imgs_cifar10_test = data_cifar10_test.reshape(data_cifar10_test.shape[0], 3, 32, 32)

    data_cifar100_train, labels_cifar100_train, img_names_cifar100_train = \
        load_cifar100(cifar100_path, mode='train')
    data_cifar100_test, labels_cifar100_test, img_names_cifar100_test = \
        load_cifar100(cifar100_path, mode='test')
    imgs_cifar100_train = data_cifar100_train.reshape(data_cifar100_train.shape[0], 3, 32, 32)
    imgs_cifar100_test = data_cifar100_test.reshape(data_cifar100_test.shape[0], 3, 32, 32)

    # Visualize CIFAR-10
    label_names_cifar10 = obj_cifar10['label_names']
    random_visualize(imgs=imgs_cifar10_train,
                     labels=labels_cifar10_train,
                     label_names=label_names_cifar10)

    # Visualize CIFAR-100 (fine classes)
    label_names_cifar100 = obj_cifar100['fine_label_names']
    random_visualize(imgs=imgs_cifar100_train,
                     labels=labels_cifar100_train,
                     label_names=label_names_cifar100)
```
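As a quick sanity check on the loaded labels, you can count the samples per class with `np.bincount`. The CIFAR-10 training set should give 5,000 images per class and the CIFAR-100 training set 500 per fine class, since 50,000 images are split evenly over 10 or 100 classes. The `class_counts` helper below is a hypothetical addition, not part of the script above.

```python
import numpy as np


def class_counts(labels, num_classes):
    """Number of samples for each class index 0..num_classes-1."""
    # minlength keeps classes with zero samples in the result
    return np.bincount(np.asarray(labels), minlength=num_classes)


# Example with the variables from the script above:
# counts = class_counts(labels_cifar10_train, 10)
# print(counts.min(), counts.max())   # expected: 5000 5000 for CIFAR-10
```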

The visualization results are shown below:

Figure: random samples from CIFAR-10

Figure: random samples from CIFAR-100
