当前位置:   article > 正文

torchvision.transforms 数据预处理:ToTensor()_totensor的作用

totensor的作用

ToTensor() 是pytorch中的数据预处理函数,包含在 torchvision.transforms 模块下。一般用于处理图像数据,所以其处理对象是 PIL Image 和 numpy.ndarray 。

1、ToTensor() 函数的作用

必须要声明不能只看函数名,就以为 ToTensor() 只是将图像转为 tensor。

ToTensor() 函数的源码:
 

  1. class ToTensor:
  2. """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
  3. Converts a PIL Image or numpy.ndarray (H x W x C) in the range
  4. [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
  5. if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
  6. or if the numpy.ndarray has dtype = np.uint8
  7. In the other cases, tensors are returned without scaling.
  8. .. note::
  9. Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
  10. transforming target image masks. See the `references`_ for implementing the transforms for image masks.
  11. .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
  12. """

大意是:

(1)将 PIL Image 或 numpy.ndarray 转为 tensor

(2)如果 PIL Image 属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 中的一种图像类型,或者 numpy.ndarray 格式数据类型是 np.uint8 ,则将 [0, 255] 的数据转为 [0.0, 1.0] ,也就是说将所有数据除以 255 进行归一化。

(3)将 HWC 的图像格式转为 CHW 的 tensor 格式。CNN训练时需要的数据格式是[N,C,N,W],也就是说经过 ToTensor() 处理的图像可以直接输入到CNN网络中,不需要再进行reshape。
 

这是一个非常常用的转换。在PyTorch中,我们主要处理张量形式的数据。如果输入数据是NumPy数组或PIL图像的形式,我们可以使用ToTensor将其转换为张量格式。最后一个张量的形式是(C * H * W)。同时,还执行从0–255到0–1的范围内的缩放操作。

2、ToTensor() 的使用

(1)np.array 整型的默认数据类型为 np.int32,经过 ToTensor() 后数值不变,不进行归一化。
(2)np.array 整型的默认数据类型为 np.float64,经过 ToTensor() 后数值不变,不进行归一化。
(3)opencv 读取的图像格式为 np.array,其数据类型为 np.uint8。
    经过 ToTensor() 后数值由 [0,255] 变为 [0,1],通过将每个数据除以255进行归一化。
(4)经过 ToTensor() 后,HWC 的图像格式变为 CHW 的 tensor 格式。
(5)np.uint8 和 np.int8 不一样,uint8是无符号整型,数值都是正数。
(6)ToTensor() 可以处理任意 shape 的 np.array,并不只是三通道的图像数据。

(1) np.uint8 类型

  1. import numpy as np
  2. from torchvision import transforms
  3. data = np.array([
  4. [0, 5, 10, 20, 0],
  5. [255, 125, 180, 255, 196]
  6. ], dtype=np.uint8)
  7. tensor = transforms.ToTensor()(data)
  8. print(tensor)
  9. """
  10. tensor([[[0.0000, 0.0196, 0.0392, 0.0784, 0.0000],
  11. [1.0000, 0.4902, 0.7059, 1.0000, 0.7686]]])
  12. """

(2)非 np.uint8 类型

  1. import numpy as np
  2. from torchvision import transforms
  3. data = np.array([
  4. [0, 5, 10, 20, 0],
  5. [255, 125, 180, 255, 196]
  6. ]) # data.dtype = int32
  7. tensor = transforms.ToTensor()(data)
  8. print(tensor)
  9. """
  10. tensor([[[ 0, 5, 10, 20, 0],
  11. [255, 125, 180, 255, 196]]], dtype=torch.int32)
  12. """

原文链接:

torchvision.transforms 数据预处理:ToTensor()_ctrl A_ctrl C_ctrl V的博客-CSDN博客

用于记录,方便自己查阅

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/351038
推荐阅读
相关标签
  

闽ICP备14008679号