当前位置:   article > 正文

Pytorch学习 day06(torchvision中的datasets、dataloader)

Pytorch学习 day06(torchvision中的datasets、dataloader)

torchvision的datasets

  • 使用torchvision提供的数据集API,比较方便,
  • 如果在pycharm中下载很慢,可以URL链接到迅雷中进行下载(有些URL链接在源码里)
  • 用来告诉程序,数据集存储的位置,共有多少样本等
  • 代码如下:
import torchvision  # 导入 torchvision 库
# 使用torchvision的datasets模块,模块中包含CIFAR10、CIFAR100、ImageNet、COCO等数据集
train_set = torchvision.datasets.CIFAR10("./Dataset", train = True, download = True)    # root 表示数据集的存储路径,train 表示是否是训练集,download 表示是否需要下载
test_set = torchvision.datasets.CIFAR10("./Dataset", train = False, download = True)
  • 1
  • 2
  • 3
  • 4
  • CIFAR10数据集的每个样本会输出一个元组,第一个元素是PIL格式的图片,第二个元素是该样本的标签,即class,代码如下:
import torchvision  # 导入 torchvision 库
# 使用torchvision的datasets模块,模块中包含CIFAR10、CIFAR100、ImageNet、COCO等数据集
train_set = torchvision.datasets.CIFAR10("./Dataset", train = True, download = True)    # root 表示数据集的存储路径,train 表示是否是训练集,download 表示是否需要下载
test_set = torchvision.datasets.CIFAR10("./Dataset", train = False, download = True)


print(train_set[0])  # 输出训练集的第一个样本 ,输出为一个元组,第一个元素为PIL格式图片,第二个元素为标签,标签表示图片的类别,即class
print(train_set.classes) # 输出数据集的类别,即class
img, target = train_set[0]
print(img)  # 输出图片
print(target)  # 输出标签
print(train_set.classes[target])  # 输出训练集第一个样本图片的类别
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 对数据集进行transforms变换
    • 注意,只需要在调用数据集API时,填入变换对象、或变换序列即可,由于dataset_transforms是Compose类实例化后的对象,所以直接传入即可,代码如下:
import torchvision  # 导入 torchvision 库
from torch.utils.tensorboard import SummaryWriter

dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),  # 将PIL格式图片转换为Tensor格式
])  # Compose函数将多个transforms组合在一起


# 使用torchvision的datasets模块,模块中包含CIFAR10、CIFAR100、ImageNet、COCO等数据集
train_set = torchvision.datasets.CIFAR10("./Dataset", train = True, transform=dataset_transforms, download = True)    # root 表示数据集的存储路径,train 表示是否是训练集,transform 表示对数据集进行的变换,download 表示是否下载数据集
test_set = torchvision.datasets.CIFAR10("./Dataset", train = False, transform=dataset_transforms, download = True)

writer = SummaryWriter("logs")  # 实例化SummaryWriter类,参数log_dir表示日志文件的存储路径
for i in range(10):
    img, target = train_set[i]  
    writer.add_image("train_set_img", img, i) # 将图片写入tensorboard
    
writer.close()  # 关闭SummaryWriter对象
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • tensorboard的展示结果如下:
    在这里插入图片描述

torchvision中的dataloader

  • datasets的加载器,把数据加载到神经网络中,需要手动设置某些参数,如下:
    • dataset:就是上一节自定义的datasets对象
    • batch_size:每次从数据集中取多少个样本,并打包输入进神经网络
    • shuffle:每轮epoch样本抽取完毕后,需不需要打乱数据集,True–需要,False–不需要
    • num_workers:加载数据集时,采用多少进程来进行加载,默认为0,采用主进程来进行加载
    • drop_last:最后一次抽取样本时,如果不够一个batch_size,剩余的样本是否舍弃,True–舍弃,False–不舍弃
  • 由于dataloader中的sampler默认为RandomSampler随机采样,所以dataloader在每个batch_size中都是以随机策略在数据集中抓取的,如下:
    在这里插入图片描述
  • 输出结果:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

test_data = datasets.CIFAR10(root="./Dataset", train=False, transform=transforms.ToTensor(), download=True)

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# 测试集中的第一个样本图像和标签
img, target = test_data[0]
print(img.shape)
print(target)

# 测试集中的第一个batch的图像和标签
for data in test_loader:
    img, target = data
    print(img.shape)
    print(target)
    break

# 输出结果:
# 第一次:
# Files already downloaded and verified
# torch.Size([3, 32, 32])
# 3
# torch.Size([4, 3, 32, 32])
# tensor([1, 5, 9, 9])

# 第二次:
# Files already downloaded and verified
# torch.Size([3, 32, 32])
# 3
# torch.Size([4, 3, 32, 32])
# tensor([5, 9, 0, 5])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 由于dataset中的_ _ getitem _ _方法返回img、target,所以如果batch_size为4,那么dataloader会以4个为一组,分别打包img和target,并返回imgs和targets,如下图:
    在这里插入图片描述
  • 且返回的imgs和targets都是Tensor数据类型,如下:
    在这里插入图片描述
  • 同时,由于imgs为Tensor数据类型,且满足(N, C, H, W)的形式,所以可以直接采用tensorboard进行展示输出,如下:
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

test_data = datasets.CIFAR10(root="./Dataset", train=False, transform=transforms.ToTensor(), download=True)

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# 测试集中的第一个样本图像和标签
# img, target = test_data[0]
# print(img.shape)
# print(target)

# 测试集中的第一个batch的图像和标签
writer = SummaryWriter("logs")
step = 0
for data in test_loader:
    img, target = data
    print(img.shape)
    print(target)
    writer.add_images("test_loadimages", img, step) # 因为img是一个batch的图像,所以要用add_images
    step += 1

writer.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 展示结果:
    在这里插入图片描述
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/219378?site
推荐阅读
相关标签
  

闽ICP备14008679号