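In short, training data does not always come in the format a machine learning model expects. In that case, transforms are used to manipulate (convert) the data so that it is suitable as input to a neural network. For example, an image is usually stored as a 3-D tensor (height, width, channels), whereas a fully connected network expects a flattened 1-D tensor; a transform can perform that flattening.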
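All TorchVision datasets accept two parameters: transform, which modifies the features, and target_transform, which modifies the labels. The torchvision.transforms module provides several ready-to-use transforms.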
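To make the role of the two parameters concrete, here is a minimal sketch of a dataset that stores both callables and applies them when a sample is fetched. `ToyDataset` and its toy features and labels are hypothetical; the sketch only assumes that TorchVision datasets apply the hooks in a similar way inside `__getitem__`, which is a simplification of the real classes.

```python
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset showing how transform/target_transform are applied."""

    def __init__(self, features, labels, transform=None, target_transform=None):
        self.features = features
        self.labels = labels
        self.transform = transform                # applied to each feature
        self.target_transform = target_transform  # applied to each label

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x, y = self.features[idx], self.labels[idx]
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y


ds = ToyDataset(
    features=[[1.0, 2.0], [3.0, 4.0]],
    labels=[0, 1],
    transform=lambda x: torch.tensor(x) * 2,  # feature: list -> doubled tensor
    target_transform=lambda y: y + 10,         # label: shifted integer
)
x, y = ds[1]
print(x, y)  # tensor([6., 8.]) 11
```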
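The features of FashionMNIST are images in PIL Image format, and the labels are integers. For training, the features need to be normalized tensors and the labels need to be one-hot encoded tensors. ToTensor and Lambda are used to perform these conversions.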
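import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",           # directory where the data is stored or downloaded to
    train=True,            # use the training split
    download=True,         # download the data if it is not already under root
    transform=ToTensor(),  # features: PIL Image -> FloatTensor scaled to [0, 1]
    # labels: integer -> one-hot FloatTensor of length 10
    # https://blog.csdn.net/xinjieyuan/article/details/106672340
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)),
)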
Out:
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
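ToTensor converts a PIL Image or a NumPy ndarray of shape H × W × C into a FloatTensor of shape C × H × W and scales the pixel values into the range [0., 1.]. The scaling applies to uint8 arrays and to PIL images in the common modes.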
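Lambda applies any user-defined lambda function. In this example, the function turns an integer label into a one-hot encoded tensor: it first creates a zero tensor of size 10 (the number of labels in the dataset) and then calls scatter_, which sets value=1 at the index given by the label y.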