赞
踩
现在我们定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。 这个函数返回训练集和验证集的数据迭代器。 此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。
代码如下:
import torchvision from torchvision import transforms from torch.utils import data def get_dataloader_workers(): return 0 def load_data_fashion_mnist(batch_size, resize=None): #@save """下载Fashion-MNIST数据集,然后将其加载到内存中""" trans = [transforms.ToTensor()] if resize: trans.insert(0, transforms.Resize(resize)) trans = transforms.Compose(trans) mnist_train = torchvision.datasets.FashionMNIST( root="./15.动手学深度学习代码手撸/data", train=True, transform=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST( root="./15.动手学深度学习代码手撸/data", train=False, transform=trans, download=True) return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()), data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))
下面,我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。
train_iter, test_iter = load_data_fashion_mnist(32, resize = 64)
for X, y in train_iter:
print(X.shape, y.shape)
break
我们的输出结果为:
torch.Size([32, 1, 64, 64]) torch.Size([32])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。