赞
踩
数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用转换来对数据进行一些操作,并使其适合训练。
所有TorchVision数据集都有两个参数-transform用于修改功能和target_transform用于修改标签-接受包含转换逻辑的可调用。torchvision.transforms模块提供了几个常用的开箱即用的转换。
FashionMNIST功能采用PIL图像格式,标签为整数。为了训练,我们需要作为归一化张量的特征,以及作为单热编码张量的标签。为了进行这些转换,我们使用ToTensor和Lambda。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
ToTensor将PIL图像或NumPy ndarray转换为FloatTensor。并在[0.,1.]范围内缩放图像的像素强度值。
Torchvision支持torchvision.transforms和torchvision.transforms.v2模块中的常见计算机视觉转换。转换可用于转换或增强数据,以训练或推断不同任务(图像分类、检测、分割、视频分类)。
# Image Classification import torch from torchvision.transforms import v2 H, W = 32, 32 img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8) transforms = v2.Compose([ v2.RandomResizedCrop(size=(224, 224), antialias=True), v2.RandomHorizontalFlip(p=0.5), v2.ToDtype(torch.float32, scale=True), v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) img = transforms(img)
# Detection (re-using imports and transforms from above)
from torchvision import tv_tensors
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
boxes = torch.randint(0, H // 2, size=(3, 4))
boxes[:, 2:] += boxes[:, :2]
boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
# The same transforms can be used!
img, boxes = transforms(img, boxes)
# And you can pass arbitrary input structures
output_dict = transforms({"image": img, "boxes": boxes})
Lambda变换应用任何用户定义的lambda函数。在这里,我们定义了一个函数,将整数转换为单热编码张量。它首先创建一个大小为10的零张量(我们数据集中的标签数量),并调用scatter_,在标签y给出的索引上分配一个value=1。
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。