赞
踩
['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boots']
部分服饰为:
导入相应的库,并下载数据集
%matplotlib inline import torch from IPython import display import torchvision # torchvision是关于图像操作的一些方便工具库,对于计算机视觉进行实现的一个库 from torch.utils import data # 用来读取数据 from torchvision import transforms # 为pytorch中图像预处理包,包含了很多种对图像进行变化的函数 from d2l import torch as d2l import matplotlib.pyplot as plt import time def use_svg_display(): # 用矢量图显示图片 display.set_matplotlib_formats('svg') # format格式 use_svg_display() # 用svg显示图片,这样图片的清晰度会更高 # 下载数据集 trans = transforms.ToTensor() # 把shape为(x, y, z)的转换为(z, x, y),并每个元素除以255 # 得到每个元素的数值均在0到1之间 mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)
数据集的探索
len(mnist_train), len(mnist_test)
# Output: (60000, 10000) -- 60000 training images and 10000 test images
mnist_train[0][0].shape
# Output: torch.Size([1, 28, 28]) -- channel count and size of a single image
数据集的可视化,结果为简介中的图片
def get_fashion_mnist_labels(labels):
    """Return the text labels for Fashion-MNIST class indices.

    Args:
        labels: Iterable of class indices (anything `int()` accepts,
            e.g. the label tensor of a batch).

    Returns:
        A list of label strings, one per entry of `labels`.
    """
    # Index i in this list is the name of class i; class 1 is 'trouser'.
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boots']
    return [text_labels[int(i)] for i in labels]


def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images on a `num_rows` x `num_cols` grid.

    Args:
        imgs: Iterable of images; tensors are converted with `.numpy()`,
            anything else is passed to `imshow` as-is.
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.
        titles: Optional sequence of per-image titles, indexed like `imgs`.
        scale: Size of each grid cell in inches.

    Returns:
        The flattened array of matplotlib axes.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always yields a 2-D axes array, so flatten() also works
    # for a 1x1 grid (where subplots would otherwise return a bare Axes).
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                               squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        # Hide the tick marks: they carry no information for image tiles.
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes


# Take the first 18 training images and draw them on a 2 x 9 grid.
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
images = show_images(X.reshape(18, 28, 28), 2, 9,
                     titles=get_fashion_mnist_labels(y))
images
# Save the generated figure to disk.
plt.savefig('部分服饰.png', facecolor='white', edgecolor='red')
把数据集通过函数形式导入到内存中
def load_data_fashion_mnist(batch_size, resize=None):
    """Load Fashion-MNIST and return DataLoaders for the train and test splits.

    Args:
        batch_size: Number of images per batch.
        resize: Optional target size forwarded to `transforms.Resize`;
            when falsy the images keep their original size.

    Returns:
        A `(train_iter, test_iter)` tuple of `DataLoader`s. The training
        loader shuffles its samples; the test loader does not.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Resize goes first so it runs before the conversion to a tensor.
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)  # chain the image transforms together
    # download=True makes the function usable on its own; the download is
    # skipped when the files already exist under ./data.
    mnist_train = torchvision.datasets.FashionMNIST(
        root="./data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="./data", train=False, transform=trans, download=True)
    # The worker-count helper lives in d2l; a bare get_dataloader_workers()
    # is not defined in this file.
    num_workers = d2l.get_dataloader_workers()
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=num_workers),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=num_workers))
解释两个参数的含义:
batch_size:我们一次读取多少张图片
resize:是否要对图片进行等比例的放大或缩小。eg: resize=66,则图片的尺寸变为66 x 66
# Read batches of 8 images, each resized to 12 x 12.
train_iter, test_iter = load_data_fashion_mnist(batch_size=8, resize=12)
# Only the first training batch is inspected: print its shapes and dtypes.
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
结果为:torch.Size([8, 1, 12, 12]) torch.float32 torch.Size([8]) torch.int64
说明:我们一次读取8张图片,每张图片为单通道,尺寸为12 x 12,并且每张图片都有对应的标签,一共8个标签。
# Dump the first image (as nested lists) and its label from the first test batch.
for X, y in test_iter:
    first_image = X[0].tolist()
    first_label = y[0]
    print(first_image, first_label)
    break
结果为:
[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003921568859368563, 0.003921568859368563], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.019607843831181526, 0.1411764770746231, 0.019607843831181526, 0.003921568859368563, 0.10196078568696976, 0.062745101749897], [0.0, 0.0, 0.0, 0.0, 0.0, 0.007843137718737125, 0.18431372940540314, 0.4745098054409027, 0.4745098054409027, 0.43921568989753723, 0.47058823704719543, 0.11372549086809158], [0.0, 0.0, 0.0, 0.003921568859368563, 0.003921568859368563, 0.125490203499794, 0.38823530077934265, 0.5333333611488342, 0.6039215922355652, 0.6352941393852234, 0.5803921818733215, 0.1921568661928177], [0.0, 0.003921568859368563, 0.003921568859368563, 0.03529411926865578, 0.14901961386203766, 0.3803921639919281, 0.4588235318660736, 0.5607843399047852, 0.5921568870544434, 0.6117647290229797, 0.5921568870544434, 0.3843137323856354], [0.08235294371843338, 0.1921568661928177, 0.26274511218070984, 0.3607843220233917, 0.4431372582912445, 0.4745098054409027, 0.5254902243614197, 0.5764706134796143, 0.6078431606292725, 0.6078431606292725, 0.6196078658103943, 0.5176470875740051], [0.33725491166114807, 0.47058823704719543, 0.5058823823928833, 0.49803921580314636, 0.5137255191802979, 0.5647059082984924, 0.6078431606292725, 0.6392157077789307, 0.6941176652908325, 0.800000011920929, 0.7686274647712708, 0.5333333611488342], [0.0470588244497776, 0.12156862765550613, 0.24313725531101227, 0.30588236451148987, 0.32156863808631897, 0.3176470696926117, 0.2235294133424759, 0.11764705926179886, 0.20392157137393951, 0.35686275362968445, 0.3176470696926117, 0.20000000298023224], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]
对应的标签为:tensor(9),说明该图片的类别索引为9(索引从0开始计数),即标签列表中的第10种服饰 'ankle boots'
完整代码链接:FashionMNIST数据集
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。