当前位置:   article > 正文

pytorch中的transforms介绍_torch transform

torch transform

transforms用法介绍

torchvision.transforms模块主要用于对图像进行转换等一系列预处理操作,其主要目的是对图像数据进行增强,进而提高模型的泛化能力。对图像预处理操作有数据中心化,缩放,裁剪,旋转,翻转,填充,添加噪声,灰度变换,线性变换,仿射变换,亮度,饱和度,对比变换等。

transforms.Compose

transforms.Compose是将一系列的图像转换函数进行组合,实现时能够按照这些函数的顺序依次去图像进行处理操作,需要注意的是同样的功能也可以用torch.nn.Sequential函数来实现。

CLASS torchvision.transforms.Compose(transforms)

  • transforms:表示图像变换组合的列

transforms.Compose具体实例的代码如下所示

transform_train = transforms.Compose([
    transforms.RandomCrop(cut_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_train = torch.nn.Sequential(
    transforms.RandomCrop(cut_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), 
    )

transform_test = transforms.Compose([
    transforms.TenCrop(cut_size),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
])
transform_test =  torch.nn.Sequential(
    transforms.TenCrop(cut_size),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
 )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

transforms.ToTensor

transforms.ToTensor的作用是将一个PIL Image格式的图片或者是取值范围为 [ 0 , 255 ] [0,255] [0,255],形状为 [ H × W × C ] [\mathrm{H} \times \mathrm{W} \times \mathrm{C}] [H×W×C]numpy.ndarray的数组转换为取值范围为 [ 0.0 , 1.0 ] [0.0,1.0] [0.0,1.0],形状为 [ C × H × W ] [\mathrm{C}\times \mathrm{H}\times \mathrm{W}] [C×H×W]tensor格式图片。

transforms.RandomCrop

transforms.RandomCrop的作用是在图片的随机位置上进行裁剪并返回新的图片。

CLASS torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode=‘constant’)

  • size:表示裁剪图片的输出尺寸,如果参数是一个整数则裁剪的是一个正方形
  • padding:表示图像每个边框上的可选填充。默认值是None
  • pad_if_needed:如果图像小于所需大小,它将填充图像,以避免引发异常
  • fill:表示像素填充值,默认值为0。如果元组长度为3,则用于分别填充R、G、B通道
  • padding_mode:表示像素填充值的类型,默认是常值,也有边缘填充,反射和对称

transforms.RandomCrop具体实例的代码实现和对应的可视化图如下所示

from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.io import read_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)

def show(imgs):
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

img1 = read_image(str(Path('assets') / 'KOBE1.png'))
img2 = read_image(str(Path('assets') / 'KOBE2.png'))
show([img1, img2])

transforms = T.Compose([
		T.RandomCrop(224),
	])

# transforms = torch.nn.Sequential(
#     T.RandomCrop(224),
# )

device = 'cuda' if torch.cuda.is_available() else 'cpu'
img1 = img1.to(device)
img2 = img2.to(device)

transformed_img1 = transforms(img1)
transformed_img2 = transforms(img2)
show([transformed_img1, transformed_img2])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

transforms.RandomHorizontalFlip

transforms.RandomHorizontalFlip的作用是以特定的概率将图片进行水平翻转。

CLASS torchvision.transforms.RandomHorizontalFlip(p=0.5)

  • p:表示图片水平翻转的概率,默认值是0.5

transforms.RandomHorizontalFlip具体实例的代码实现和对应的可视化图如下所示

from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.io import read_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)

def show(imgs):
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

img1 = read_image(str(Path('assets') / 'KOBE1.png'))  # type : torch
img2 = read_image(str(Path('assets') / 'KOBE2.png'))  # type : torch
show([img1, img2])


transforms = T.Compose([
		 T.RandomHorizontalFlip(p=0.9),
	])

# transforms = torch.nn.Sequential(
#      T.RandomHorizontalFlip(p=0.3),
# )

device = 'cuda' if torch.cuda.is_available() else 'cpu'
img1 = img1.to(device)
img2 = img2.to(device)

transformed_img1 = transforms(img1)
transformed_img2 = transforms(img2)
show([transformed_img1, transformed_img2])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

transforms.TenCrop

transforms.RandomCrop的作用是可以将一张图片的四个角和中心进行裁剪后,然后加上返回的翻转后共10张图片,其中默认翻转是水平翻转。

CLASS torchvision.transforms.TenCrop(size, vertical_flip=False)

  • size:表示裁剪图片的输出尺寸,如果参数是一个整数则裁剪的是一个正方形
  • vertical_flip:表示图片是否用垂直翻转代替水平翻转None

需要注意的是transforms.TenCrop函数的输入必须是 P I L \mathrm{PIL} PIL的图片格式,具体实例的代码实现和对应的可视化图如下所示

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'

torch.manual_seed(0)


def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
    plt.show()

orig_img = Image.open(Path('assets') / 'KOBE1.png')   # tyep : PIL
(top_left, top_right, bottom_left, bottom_right, center, 
    flip_top_left, flip_top_right, flip_bottom_left, flip_bottom_right, flip_center) = T.TenCrop(size=(200,200))(orig_img)
plot([[top_left, top_right, bottom_left, bottom_right, center], 
        [flip_top_left, flip_top_right, flip_bottom_left, flip_bottom_right, flip_center]],with_orig=False)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

格外需要注意transforms.TenCrop对于每张图片会返回10张变换后的图片,尤其是在测试阶段会导致图片数量和标签数量不匹配,可以进行如下处理

transform = Compose([
	   FiveCrop(size), # this is a list of PIL Images
       Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
	 ])
#In your test loop you can do the following:
input, target = batch # input is a 5d tensor, target is 2d
bs, ncrops, c, h, w = input.size()
result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/351006
推荐阅读
相关标签
  

闽ICP备14008679号