当前位置:   article > 正文

PyTorch深度学习入门笔记(五)Transforms的使用_from torchvision import transforms

from torchvision import transforms

课程学习笔记,课程链接
学习笔记同步发布在我的个人网站上,欢迎来访查看。

一、Transforms的使用

torchvision中的transforms主要是对图片进行一些变换。
tranforms对应 tranforms.py 文件,里面定义了很多类,输入一个图片对象,返回经过处理的图片对象。
在这里插入图片描述
transforms.py就像一个工具箱,里面定义的各种类就像各种工具,图片就是输入对象,经过工具处理,输出期望的图片结果。
在这里插入图片描述
现在通过 transforms.ToTensor去看两个问题:

  • 1、transforms该如何使用(python)
  • 2、为什么我们需要 Tensor 数据类型
    在这里插入图片描述
    ToTensor功能是将 PIL Image 类型 或者numpy.ndarray类型的图片对象转换为 tensor类型。
    使用Demo:
from torchvision import transforms
from PIL import Image

img_path = "testdata/train/ants_image/6743948_2b8c096dda.jpg"
img = Image.open(img_path)
print(img)
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
print(tensor_img)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

在这里插入图片描述
所以使用transforms的方法就是 先实例化选中的类,然后用实例化的对象去处理图片就行。
在这里插入图片描述

二、Tensor数据类型

将第一节中的代码复制到 python 控制台,回车,可在右侧看到各种变量和对象的具体信息:
在这里插入图片描述
tensor 数据类型可以理解为包装了反向神经网络一些理论基础参数。在神经网络中,要将数据先转换为Tensor类型,再进行训练。
测试代码:

from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from PIL import Image

img_path = "testdata/train/ants_image/6743948_2b8c096dda.jpg"
img = Image.open(img_path)

writer = SummaryWriter("logs")

tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)

writer.add_image("Tensor_Image",tensor_img)
writer.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

结果:
在这里插入图片描述

三、常见的Transforms

常用的输入图片对象的数据类型

  • PIL : Image.open()
  • tensor : ToTensor()
  • ndarrays: cv.imread()

常用的Transform有:

  1. ToTensor() :将图片对象类型转为 tensor
  2. Normalize() :对图像像素进行归一化计算
  3. Resize():重新设置 PIL Image的大小,返回也是PIL Image格式
  4. Compose(): 输入为 transforms类型参数的列表,即
Compose([transforms参数1, transforms参数2], ...)
  • 1

目的是将几个 transforms操作打包成一个,比如要先进行大小调整,然后进行归一化计算,返回tensor类型,则可以将 ToTensor、Normalize、Resize,按操作顺序输入到Compose中。
示例代码:

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import os

root_path = "hymenoptera_data/train/ants"
img_name = "7759525_1363d24e88.jpg"
img_path = os.path.join(root_path,img_name)
img = Image.open(img_path)

writer = SummaryWriter("logs")

# ToTensor
trans_totensor = transforms.ToTensor() # instantiation
img_tensor = trans_totensor(img)
writer.add_image("Tensor", img_tensor)

# Normalize
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image("Normalize", img_norm)

#Resize
print(img.size)
trans_resize = transforms.Resize((512,512))
img_resize = trans_resize(img) # return type still is PIL image
img_resize = trans_totensor(img_resize)
writer.add_image("Resize", img_resize)

# Compose - resize -2
trans_resize_2 = transforms.Resize(512)
tran_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize2 = tran_compose(img)
writer.add_image("Compose", img_resize2)

writer.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

  • 关注输入和输出类型
  • 多看官方文档
  • 关注方法需要什么参数
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/351036
推荐阅读
相关标签
  

闽ICP备14008679号