赞
踩
比赛数据集分为训练集和测试集,其中训练集包含50000张、测试集包含300000张图像。 在测试集中,10000张图像将被用于评估,而剩下的290000张图像将不会被进行评估,包含它们只是为了防止手动标记测试集并提交标记结果。 两个数据集中的图像都是png格式,高度和宽度均为32像素并有三个颜色通道(RGB)。 这些图片共涵盖10个类别:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。
比赛网址为:https://www.kaggle.com/c/cifar-10
在…/data中解压下载的文件并在其中解压缩train.7z和test.7z后的结构如下,train和test文件夹分别包含训练和测试图像,trainLabels.csv含有训练图像的标签, sample_submission.csv是提交文件的范例:
…/data/cifar-10/train/[1-50000].png
…/data/cifar-10/test/[1-300000].png
…/data/cifar-10/trainLabels.csv
…/data/cifar-10/sampleSubmission.csv
为了便于入门,我们提供包含前1000个训练图像和5个随机测试图像的数据集的小规模样本。 要使用Kaggle竞赛的完整数据集,你需要将以下demo变量设置为False,并且把Kaggle上面数据集下载下来放到指定文件夹下面。
import collections import math import os.path import shutil import d2l.torch import torch import torchvision.transforms import torch.utils.data from torch import nn import pandas as pd d2l.torch.DATA_HUB['cifar10_tiny'] = (d2l.torch.DATA_URL+'kaggle_cifar10_tiny.zip','2068874e4b9a9f0fb07ebe0ad2b29754449ccacd') demo = True if demo: data_dir = d2l.torch.download_extract('cifar10_tiny') else: data_dir = '../data/cifar-10'
def read_csv_data(fname):
with open(fname,'r') as f:
lines = f.readlines()[1:]
#读取文件中每一行数据,获取数据集的标签和对应的图片索引号,需要去除第一行标题名称
tokens = [line.rstrip().split(',') for line in lines]
return dict(((name,label) for name,label in tokens))
labels = read_csv_data(os.path.join(data_dir,'trainLabels.csv'))
print('样本数:',len(labels))
print('类别数:',len(set(labels.values())))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。