当前位置:   article > 正文

pytorch学习笔记(四):输入流水线(input pipeline)_# data loader (input pipeline什么意思

# data loader (input pipeline什么意思

input-pipeline

引包

from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
  • 1
  • 2
  • 3

图像预处理

# 创建个transform用来处理图像数据
transform = transforms.Compose([
    transforms.Scale(40),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

准备数据

# 下载数据
train_dataset = dsets.CIFAR10(root='./data/',
                               train=True,
                               transform=transform,#用了之前定义的transform
                               download=True)

image, label = train_dataset[0]
print (image.size())
print (label)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
Files already downloaded and verified
torch.Size([3, 32, 32])
6
  • 1
  • 2
  • 3

加载数据

# data loader提供了队列和线程
train_loader = data.DataLoader(dataset=train_dataset,
                               batch_size=100,# 这里定义了batch_size
                               shuffle=True,
                               num_workers=2)
  • 1
  • 2
  • 3
  • 4
  • 5
# 迭代开始,然后,队列和线程跟着也开始
data_iter = iter(train_loader)

# mini-batch 图像 和 标签
images, labels = next(data_iter)

for images, labels in train_loader:
    # 这里是训练代码
    pass
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/907797
推荐阅读
相关标签
  

闽ICP备14008679号