当前位置:   article > 正文

Pytorch之Dataloader(数据集下载及分批次导入)_pydataloader下载

pydataloader下载

首先是加载标准数据集torchvision可以直接加载的数据集

建议去官网找doc>torchvision>dataset然后点击右面的数据集又详细的使用教程

比如这里的CIFAR10

当然,也可以直接ctrl+Q查看如何加载

加载代码如下

  1. import torchvision
  2. train_set = torchvision.datasets.CIFAR10(root='./CIFAR10', train=True, download=True)
  3. test_set = torchvision.datasets.CIFAR10(root='./CIFAR10',train=False,download=True)
  4. img,target = test_set[0]
  5. print(img.shape)
  6. print(target)

好的数据集已经加载了,现在如何接住这些数据呢?
使用dataloader把数据加载成一批一批的

from torch.utils.data import DataLoader

但是,由于加载的数据需要是tensor,先修改一下之前的代码,类似下面这样加上ToTensor的变换

  1. # 测试集
  2. test_set = torchvision.datasets.CIFAR10(root='./CIFAR10',train=False, transform=torchvision.transforms.ToTensor(), download=False)

然后使用dataloader分批数据,每批64个.随机,不舍弃最后一组不完整的批

test_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

最后打开tensorboard搞个循环测试一下

  1. writer = SummaryWriter('logs')
  2. step = 0
  3. for data in test_loader:
  4. imgs,target = data
  5. print(img.shape)
  6. print(target)
  7. writer.add_images("test_data",imgs, step)
  8. step = step+1
  9. writer.close()

打开tensorboard会显示每批加载的数据

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/455048
推荐阅读
相关标签
  

闽ICP备14008679号