当前位置:   article > 正文

你需要知道的11个Torchvision计算机视觉数据集

svhn数据集介和cifar10数据集

7f9a86014f4c66714535b67c7ec6329b.png

来源:新机器视觉

本文约3800字,建议阅读8分钟

本文介绍了11个Torchvision计算机视觉数据集。

计算机视觉是一个显著增长的领域,有许多实际应用,从自动驾驶汽车到面部识别系统。该领域的主要挑战之一是获得高质量的数据集来训练机器学习模型。

Torchvision作为Pytorch的图形库,一直服务于PyTorch深度学习框架,主要用于构建计算机视觉模型。


为了解决这一挑战,Torchvision提供了访问预先构建的数据集、模型和专门为计算机视觉任务设计的转换。此外,Torchvision还支持CPU和GPU的加速,使其成为开发计算机视觉应用程序的灵活且强大的工具。

01 什么是“Torchvision数据集”?


Torchvision数据集是计算机视觉中常用的用于开发和测试机器学习模型的流行数据集集合。运用Torchvision数据集,开发人员可以在一系列任务上训练和测试他们的机器学习模型,例如,图像分类、对象检测和分割。数据集还经过预处理、标记并组织成易于加载和使用的格式。


据了解,Torchvision包由流行的数据集、模型体系结构和通用的计算机视觉图像转换组成。简单地说就是“常用数据集+常见模型+常见图像增强”方法。


Torchvision中的数据集共有11种:MNIST、CIFAR-10等,下面具体说说。

02 Torchvision中的11种数据集

1、MNIST手写数字数据库


这个Torchvision数据集在机器学习和计算机视觉领域中非常流行和广泛应用。它由7万张手写数字0-9的灰度图像组成。其中,6万张用于训练,1万张用于测试。每张图像的大小为28×28像素,并有相应的标签表示它所代表的数字。


要访问此数据集,您可以直接从Kaggle下载或使用torchvision加载数据集:

  1. import torchvision.datasets as datasets# Load the training dataset
  2. train_dataset = datasets.MNIST(root='data/', train=True, transform=None, download=True)# Load the testing dataset
  3. test_dataset = datasets.MNIST(root='data/', train=False, transform=None, download=True)

左右滑动查看完整代码

622c74fe283873f0c8cd53e492435270.jpeg

2、CIFAR-10(广泛使用的标准数据集)


CIFAR-10数据集由6万张32×32彩色图像组成,分为10个类别,每个类别有6000张图像,总共有5万张训练图像和1万张测试图像。这些图像又分为5个训练批次和一个测试批次,每个批次有1万张图像。数据集可以从Kaggle下载。

  1. import torchimport torchvisionimport torchvision.transforms as transforms
  2. transform = transforms.Compose(
  3. [transforms.ToTensor(),
  4. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  5. trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
  6. download=True, transform=transform)
  7. testset = torchvision.datasets.CIFAR10(root='./data', train=False,
  8. download=True, transform=transform)
  9. trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
  10. shuffle=True, num_workers=2)
  11. testloader = torch.utils.data.DataLoader(testset, batch_size=4,
  12. shuffle=False, num_workers=2)

左右滑动查看完整代码

在此提醒一句,您可以根据需要调整数据加载器的批处理大小和工作进程的数量。

3、CIFAR-100(广泛使用的标准数据集)


CIFAR-100数据集在100个类中有60,000张(50,000张训练图像和10,000张测试图像)32×32的彩色图像。每个类有600张图像。这100个类被分成20个超类,用一个细标签表示它的类,另一个粗标签表示它所属的超类。

  1. import torchimport torchvisionimport torchvision.transforms as transforms
  2. import torchvision.datasets as datasetsimport torchvision.transforms as transforms# Define transform to normalize data
  3. transform = transforms.Compose([
  4. transforms.ToTensor(),
  5. transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
  6. # Load CIFAR-100 train and test datasets
  7. trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
  8. testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
  9. # Create data loaders for train and test datasets
  10. trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
  11. testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

左右滑动查看完整代码

4、ImageNet数据集


Torchvision中的ImageNet数据集包含大约120万张训练图像,5万张验证图像和10万张测试图像。数据集中的每张图像都被标记为1000个类别中的一个,如“猫”、“狗”、“汽车”、“飞机”等。

  1. import torchvision.datasets as datasetsimport torchvision.transforms as transforms
  2. # Set the path to the ImageNet dataset on your machine
  3. data_path = "/path/to/imagenet"
  4. # Create the ImageNet dataset object with custom options
  5. imagenet_train = datasets.ImageNet(
  6. root=data_path,
  7. split='train',
  8. transform=transforms.Compose([
  9. transforms.Resize(256),
  10. transforms.RandomCrop(224),
  11. transforms.RandomHorizontalFlip(),
  12. transforms.ToTensor(),
  13. transforms.Normalize(
  14. mean=[0.485, 0.456, 0.406],
  15. std=[0.229, 0.224, 0.225])
  16. ]),
  17. download=False)
  18. imagenet_val = datasets.ImageNet(
  19. root=data_path,
  20. split='val',
  21. transform=transforms.Compose([
  22. transforms.Resize(256),
  23. transforms.CenterCrop(224),
  24. transforms.ToTensor(),
  25. transforms.Normalize(
  26. mean=[0.485, 0.456, 0.406],
  27. std=[0.229, 0.224, 0.225])
  28. ]),
  29. download=False)
  30. # Print the number of images in the training and validation setsprint("Number of images in the training set:", len(imagenet_train))print("Number of images in the validation set:", len(imagenet_val))

左右滑动查看完整代码

a5e552a0b675d4b8c666226b539723f2.jpeg

5、MSCoco数据集


Microsoft Common Objects in Context(MS Coco)数据集包含32.8万张日常物体和人类的高质量视觉图像,通常用作实时物体检测中比较算法性能的标准。

6、Fashion-MNIST数据集


时尚MNIST数据集是由Zalando Research创建的,作为原始MNIST数据集的替代品。Fashion MNIST数据集由70000张服装灰度图像(训练集60000张,测试集10000张)组成。


图片大小为28×28像素,代表10种不同类别的服装,包括:t恤/上衣、裤子、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包和短靴。它类似于原始的MNIST数据集,但由于服装项目的复杂性和多样性,分类任务更具挑战性。这个Torchvision数据集可以从Kaggle下载。

  1. import torchimport torchvisionimport torchvision.transforms as transforms
  2. # Define transformations
  3. transform = transforms.Compose(
  4. [transforms.ToTensor(),
  5. transforms.Normalize((0.5,), (0.5,))])# Load the dataset
  6. trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
  7. download=True, transform=transform)
  8. testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
  9. download=True, transform=transform)
  10. # Create data loaders
  11. trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
  12. shuffle=True, num_workers=2)
  13. testloader = torch.utils.data.DataLoader(testset, batch_size=4,
  14. shuffle=False, num_workers=2)

左右滑动查看完整代码

7、SVHN数据集


SVHN(街景门牌号)数据集是一个来自谷歌街景图像的图像数据集,它由从街道级图像中截取的门牌号的裁剪图像组成。它包含所有门牌号及其包围框的完整格式和仅包含门牌号的裁剪格式。完整格式通常用于对象检测任务,而裁剪格式通常用于分类任务。


SVHN数据集也包含在Torchvision包中,它包含了73,257张用于训练的图像、26,032张用于测试的图像和531,131张用于额外训练数据的额外图像。

  1. import torchvisionimport torch
  2. # Load the train and test sets
  3. train_set = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=torchvision.transforms.ToTensor())
  4. test_set = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=torchvision.transforms.ToTensor())
  5. # Create data loaders
  6. train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
  7. test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

左右滑动查看完整代码

8、STL-10数据集


STL-10数据集是一个图像识别数据集,由10个类组成,总共约6000 +张图像。STL-10代表“图像识别标准训练和测试集-10类”,数据集中的10个类是:飞机、鸟、汽车、猫、鹿、狗、马、猴子、船、卡车。您可以直接从Kaggle下载数据集。

  1. import torchvision.datasets as datasetsimport torchvision.transforms as transforms
  2. # Define the transformation to apply to the data
  3. transform = transforms.Compose([
  4. transforms.ToTensor(),
  5. # Convert PIL image to PyTorch tensor
  6. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize the data])
  7. # Load the STL-10 dataset
  8. train_dataset = datasets.STL10(root='./data', split='train', download=True, transform=transform)
  9. test_dataset = datasets.STL10(root='./data', split='test', download=True, transform=transform)

左右滑动查看完整代码

9、CelebA数据集


这个Torchvision数据集是一个流行的大规模面部属性数据集,包含超过20万张名人图像。2015年,香港中文大学的研究人员首次发布了这一数据。CelebA中的图像包含40个面部属性,如,年龄、头发颜色、面部表情和性别。


此外,这些图片是从互联网上检索到的,涵盖了广泛的面部外观,包括不同的种族、年龄和性别。每个图像中面部位置的边界框注释,以及眼睛、鼻子和嘴巴的5个地标点。

  1. import torchvision.datasets as datasetsimport torchvision.transforms as transforms
  2. transform = transforms.Compose([
  3. transforms.CenterCrop(178),
  4. transforms.Resize(128),
  5. transforms.ToTensor(),
  6. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  7. celeba_dataset = datasets.CelebA(root='./data', split='train', transform=transform, download=True)

左右滑动查看完整代码

10、PASCAL VOC数据集


VOC数据集(视觉对象类)于2005年作为PASCAL VOC挑战的一部分首次引入。该挑战旨在推进视觉识别的最新水平。它由20种不同类别的物体组成,包括:动物、交通工具和常见的家用物品。这些图像中的每一个都标注了图像中物体的位置和分类。注释包括边界框和像素级分割掩码。


数据集分为两个主要集:训练集和验证集。


训练集包含大约5000张带有注释的图像,而验证集包含大约5000张没有注释的图像。此外,该数据集还包括一个包含大约10,000张图像的测试集,但该测试集的注释是不可公开的。

  1. import torchimport torchvisionfrom torchvision import transforms
  2. # Define transformations to apply to the images
  3. transform = transforms.Compose([
  4. transforms.Resize((224, 224)),
  5. transforms.ToTensor(),
  6. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
  7. # Load the train and validation datasets
  8. train_dataset = torchvision.datasets.VOCDetection(root='./data', year='2007', image_set='train', transform=transform)
  9. val_dataset = torchvision.datasets.VOCDetection(root='./data', year='2007', image_set='val', transform=transform)# Create data loaders
  10. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
  11. val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

左右滑动查看完整代码

11、Places365数据集


Places365数据集是一个大型场景识别数据集,拥有超过180万张图像,涵盖365个场景类别。Places365标准数据集包含约180万张图像,而Places365挑战数据集包含5万张额外的验证图像,这些图像对识别模型更具挑战性。

  1. import torchimport torchvisionfrom torchvision import transforms
  2. # Define transformations to apply to the images
  3. transform = transforms.Compose([
  4. transforms.Resize((224, 224)),
  5. transforms.ToTensor(),
  6. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
  7. # Load the train and validation datasets
  8. train_dataset = torchvision.datasets.Places365(root='./data', split='train-standard', transform=transform)
  9. val_dataset = torchvision.datasets.Places365(root='./data', split='val', transform=transform)# Create data loaders
  10. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
  11. val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

左右滑动查看完整代码

0440ba739dca621fee2bca708d97a3ce.jpeg

03 总结


总之,Torchvision数据集通常用于训练和评估机器学习模型,如卷积神经网络(CNNs)。这些模型通常用于计算机视觉应用,任何人都可以免费下载和使用。本文的主要图像是通过HackerNoon的AI稳定扩散模型生成的。

参考链接:

https://hackernoon.com/11-torchvision-datasets-for-computer-vision-you-need-to-know

编辑:王菁

校对:林亦霖

1f8153dd37674d548c75c699e6c891b9.png

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/948273
推荐阅读
相关标签
  

闽ICP备14008679号