赞
踩
torch.utils.data.DataLoader(dataset,
    batch_size=1,        # 每个 batch 取多少个样本
    shuffle=None,        # 是否打乱数据顺序
    sampler=None,
    batch_sampler=None,
    num_workers=0,       # 加载数据时使用的子进程数;默认 0 表示只在主进程中加载
    collate_fn=None,
    pin_memory=False,
    drop_last=False,     # 样本总数不能被 batch_size 整除时,是否舍弃最后那个不完整的 batch(True 舍弃,False 保留)
    timeout=0,
    worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')
官方描述
2. 使用
"""Demo: batch the CIFAR-10 test split with DataLoader and log the batches to TensorBoard.

View the logged images in a terminal with: tensorboard --logdir=dataloader
"""
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def main():
    """Load the CIFAR-10 test split, iterate it for two epochs, and log every batch."""
    # Convert each dataset image to a tensor.
    dataset_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])

    # CIFAR-10 test split; download=True fetches it into ./dataset if it is missing.
    test_set = torchvision.datasets.CIFAR10(
        root="./dataset", transform=dataset_transform, train=False, download=True
    )

    # Inspect the first test image and its target label.
    img, target = test_set[0]
    print(img.shape, target)

    test_loader = DataLoader(
        dataset=test_set,
        batch_size=64,
        shuffle=True,    # shuffle sample order
        num_workers=0,   # load in the main process
        drop_last=True,  # drop the final batch if it holds fewer than 64 samples
    )

    # The context manager closes the writer even if the loop raises.
    with SummaryWriter("dataloader") as writer:
        for epoch in range(2):
            for step, (imgs, _targets) in enumerate(test_loader):
                writer.add_images(f"Epoch:{epoch}", imgs, step)


if __name__ == "__main__":
    # The guard is required when num_workers > 0 on platforms that spawn worker processes.
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。