赞
踩
ToTensor() 是pytorch中的数据预处理函数,包含在 torchvision.transforms 模块下。一般用于处理图像数据,所以其处理对象是 PIL Image 和 numpy.ndarray 。
必须要声明不能只看函数名,就以为 ToTensor() 只是将图像转为 tensor。
ToTensor() 函数的源码:
- class ToTensor:
- """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
- Converts a PIL Image or numpy.ndarray (H x W x C) in the range
- [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
- if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
- or if the numpy.ndarray has dtype = np.uint8
- In the other cases, tensors are returned without scaling.
- .. note::
- Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
- transforming target image masks. See the `references`_ for implementing the transforms for image masks.
- .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
- """
大意是:
(1)将 PIL Image 或 numpy.ndarray 转为 tensor
(2)如果 PIL Image 属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 中的一种图像类型,或者 numpy.ndarray 格式数据类型是 np.uint8 ,则将 [0, 255] 的数据转为 [0.0, 1.0] ,也就是说将所有数据除以 255 进行归一化。
(3)将 HWC 的图像格式转为 CHW 的 tensor 格式。CNN训练时需要的数据格式是[N,C,N,W],也就是说经过 ToTensor() 处理的图像可以直接输入到CNN网络中,不需要再进行reshape。
这是一个非常常用的转换。在PyTorch中,我们主要处理张量形式的数据。如果输入数据是NumPy数组或PIL图像的形式,我们可以使用ToTensor将其转换为张量格式。最后一个张量的形式是(C * H * W)。同时,还执行从0–255到0–1的范围内的缩放操作。
(1)np.array 整型的默认数据类型为 np.int32,经过 ToTensor() 后数值不变,不进行归一化。
(2)np.array 整型的默认数据类型为 np.float64,经过 ToTensor() 后数值不变,不进行归一化。
(3)opencv 读取的图像格式为 np.array,其数据类型为 np.uint8。
经过 ToTensor() 后数值由 [0,255] 变为 [0,1],通过将每个数据除以255进行归一化。
(4)经过 ToTensor() 后,HWC 的图像格式变为 CHW 的 tensor 格式。
(5)np.uint8 和 np.int8 不一样,uint8是无符号整型,数值都是正数。
(6)ToTensor() 可以处理任意 shape 的 np.array,并不只是三通道的图像数据。
(1) np.uint8 类型
- import numpy as np
- from torchvision import transforms
-
- data = np.array([
- [0, 5, 10, 20, 0],
- [255, 125, 180, 255, 196]
- ], dtype=np.uint8)
-
- tensor = transforms.ToTensor()(data)
- print(tensor)
- """
- tensor([[[0.0000, 0.0196, 0.0392, 0.0784, 0.0000],
- [1.0000, 0.4902, 0.7059, 1.0000, 0.7686]]])
- """
(2)非 np.uint8 类型
- import numpy as np
- from torchvision import transforms
-
- data = np.array([
- [0, 5, 10, 20, 0],
- [255, 125, 180, 255, 196]
- ]) # data.dtype = int32
-
- tensor = transforms.ToTensor()(data)
- print(tensor)
- """
- tensor([[[ 0, 5, 10, 20, 0],
- [255, 125, 180, 255, 196]]], dtype=torch.int32)
- """
原文链接:
torchvision.transforms 数据预处理:ToTensor()_ctrl A_ctrl C_ctrl V的博客-CSDN博客
用于记录,方便自己查阅
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。