当前位置:   article > 正文

PyTorch学习之Transforms模块_torch target_transform

torch target_transform

以下内容全部来自 Transforms

Ⅰ. Transforms

简而言之,就是训练的数据有时候并不是机器学习训练的数据格式,这个时候就需要 Transforms 对数据进行一些操作(转换),使其适合做神经网络的输入。比如对于图像数据,通常是一个三维度 Tensor (长、宽、channels),但是神经网络通常需要一个拉直成一维的 Tensor,这个时候就需要用到 Transforms 对三维 Tensor 进行拉直。

所有的 TorchVision 数据集都有两个参数,transform 用于修改特征,target_transform 用于修改标签,tochvision.transforms 模块提供了几种开箱即用的转换。

FashionMNIST 的特征是 PIL Image 格式的图片,标签是整数,对于训练来说,特征需要拉直成一维的 Tensor,而标签需要作为-独热编码 Tensor,为了进行这些转换,使用 ToTensor 和 Lambda。

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root = "data",
    train = True, 
    download = True,
    transform = ToTensor(),
    # https://blog.csdn.net/xinjieyuan/article/details/106672340
    target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

Out:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gzExtracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gzExtracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

Ⅱ. ToTensor()

ToTensor 将 PIL Image 或者 NumPy ndarray 转换为 FloatTensor,并将图像像素缩放到[0, 1]

Ⅲ. Lambda Transforms

Lambda 接受一个用户自定义的 lambda 函数,在这个实例中,定义了一个函数将整数转换为一个独热编码张量,首先创建一个大小为 10 的零张量(数据集中标签的数量),并且调用 scatter_,他在标签 y 给出的索引上指定 value=1

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号