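The torchvision.transforms module provides a collection of transforms for image preprocessing. A major use is data augmentation, which in turn helps the model generalize better. Common image preprocessing operations include data centering, scaling, cropping, rotation, flipping, padding, adding noise, grayscale conversion, linear transforms, affine transforms, and adjustments of brightness, saturation, and contrast.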
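transforms.Compose chains a list of image transforms so that they are applied to the image one after another, in the given order. torch.nn.Sequential can achieve the same effect, with one restriction: every element of an nn.Sequential must be an nn.Module that works on tensors, so transforms such as transforms.ToTensor and transforms.Lambda cannot be placed in it. In return, such an nn.Sequential pipeline can be scripted with torch.jit.script, which Compose does not support.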
CLASS torchvision.transforms.Compose(transforms)
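- transforms: the list of transforms to compose
A concrete example of transforms.Compose, together with its torch.nn.Sequential counterpart, is shown below: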
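import torch
import torchvision.transforms as transforms

cut_size = 44  # example crop size (hypothetical value; choose one that fits your data)

# Compose accepts any callable, including ToTensor and Lambda; the input is a PIL Image
transform_train = transforms.Compose([
    transforms.RandomCrop(cut_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# nn.Sequential only accepts nn.Module instances, so ToTensor cannot be used here;
# the input must already be a uint8 tensor (e.g. loaded with torchvision.io.read_image)
transform_train = torch.nn.Sequential(
    transforms.RandomCrop(cut_size),
    transforms.RandomHorizontalFlip(),
    transforms.ConvertImageDtype(torch.float),  # uint8 [0, 255] -> float [0.0, 1.0]
)
transform_test = transforms.Compose([
    transforms.TenCrop(cut_size),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
])
# transform_test has no direct nn.Sequential equivalent: transforms.Lambda is not an
# nn.Module, so passing it to torch.nn.Sequential raises a TypeError.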
transforms.ToTensor converts a PIL Image, or a numpy.ndarray of shape $H \times W \times C$ with values in $[0, 255]$, into a tensor of shape $C \times H \times W$ with values in $[0.0, 1.0]$. The rescaling to $[0.0, 1.0]$ only happens for uint8 arrays and for PIL images in the usual modes (such as L, RGB, or RGBA); other inputs are converted without scaling.
transforms.RandomCrop crops the image at a random location and returns the cropped image.
CLASS torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')
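- size: the output size of the crop. If an int is given, a square crop of size (size, size) is made
- padding: optional padding added to each border of the image before cropping. Default is None
- pad_if_needed: if True, pad the image when it is smaller than the requested size instead of raising an exception
- fill: pixel value used for constant padding, default 0. If a tuple of length 3 is given, its values fill the R, G, and B channels respectively. Only used when padding_mode is 'constant'
- padding_mode: the type of padding. Default is 'constant'; 'edge', 'reflect', and 'symmetric' are also available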
A code example for transforms.RandomCrop, which plots the original images and then the cropped ones, is shown below:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.io import read_image
plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)
def show(imgs):
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))  # move to CPU and convert to PIL for plotting
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()
img1 = read_image(str(Path('assets') / 'KOBE1.png'))
img2 = read_image(str(Path('assets') / 'KOBE2.png'))
show([img1, img2])
transforms = T.Compose([
T.RandomCrop(224),
])
# transforms = torch.nn.Sequential(
# T.RandomCrop(224),
# )
device = 'cuda' if torch.cuda.is_available() else 'cpu'
img1 = img1.to(device)
img2 = img2.to(device)
transformed_img1 = transforms(img1)
transformed_img2 = transforms(img2)
show([transformed_img1, transformed_img2])
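transforms.RandomHorizontalFlip flips the image horizontally with a given probability.
CLASS torchvision.transforms.RandomHorizontalFlip(p=0.5)
- p: probability of the image being flipped. Default is 0.5
A code example for transforms.RandomHorizontalFlip, which plots the original images and then the (possibly) flipped ones, is shown below: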
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.io import read_image
plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)
def show(imgs):
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = T.ToPILImage()(img.to('cpu'))  # move to CPU and convert to PIL for plotting
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()
img1 = read_image(str(Path('assets') / 'KOBE1.png'))  # type: torch.Tensor (uint8, C x H x W)
img2 = read_image(str(Path('assets') / 'KOBE2.png'))  # type: torch.Tensor (uint8, C x H x W)
show([img1, img2])
transforms = T.Compose([
T.RandomHorizontalFlip(p=0.9),
])
# transforms = torch.nn.Sequential(
#     T.RandomHorizontalFlip(p=0.9),
# )
device = 'cuda' if torch.cuda.is_available() else 'cpu'
img1 = img1.to(device)
img2 = img2.to(device)
transformed_img1 = transforms(img1)
transformed_img2 = transforms(img2)
show([transformed_img1, transformed_img2])
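transforms.TenCrop crops the four corners and the center of an image, then does the same on a flipped copy of the image, returning 10 images in total. By default the flip is horizontal.
CLASS torchvision.transforms.TenCrop(size, vertical_flip=False)
- size: the output size of each crop. If an int is given, square crops of size (size, size) are made
- vertical_flip: whether to use a vertical flip instead of a horizontal one. Default is False
In older torchvision releases the input of transforms.TenCrop had to be a PIL image; newer releases also accept tensors. The following example uses a PIL image and plots the ten crops: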
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(0)
def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row  # orig_img is the global defined below
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
    plt.show()
orig_img = Image.open(Path('assets') / 'KOBE1.png')  # type: PIL Image
(top_left, top_right, bottom_left, bottom_right, center,
flip_top_left, flip_top_right, flip_bottom_left, flip_bottom_right, flip_center) = T.TenCrop(size=(200,200))(orig_img)
plot([[top_left, top_right, bottom_left, bottom_right, center],
[flip_top_left, flip_top_right, flip_bottom_left, flip_bottom_right, flip_center]],with_orig=False)
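Pay special attention to the fact that transforms.TenCrop returns 10 images for every input image. At test time in particular, this means the number of images no longer matches the number of labels. One way to handle this is to stack the crops into a single tensor and average the model's outputs over the crops: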
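import torch
from torchvision.transforms import Compose, Lambda, TenCrop, ToTensor

size = 224  # example crop size (hypothetical value)
transform = Compose([
    TenCrop(size),  # returns a tuple of 10 PIL Images (FiveCrop works the same way with 5)
    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])),  # returns a 4D tensor (ncrops, C, H, W)
])
# In your test loop you can do the following (model and batch come from your own code):
input, target = batch  # input is a 5D tensor (bs, ncrops, c, h, w); target has one label per image
bs, ncrops, c, h, w = input.size()
result = model(input.view(-1, c, h, w))  # fuse batch size and ncrops
result_avg = result.view(bs, ncrops, -1).mean(1)  # average over crops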
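To make the shape bookkeeping concrete, the sketch below runs the same reshape-and-average steps on a made-up batch with a toy model. The batch size, image size, number of classes, and the Flatten + Linear classifier are all hypothetical stand-ins for a real test loader and network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, ncrops, c, h, w = 4, 10, 3, 32, 32
num_classes = 7  # hypothetical number of classes

# Stand-ins for one test batch: 10 crops per image, one label per image
inputs = torch.rand(bs, ncrops, c, h, w)
target = torch.randint(0, num_classes, (bs,))

# Toy classifier (hypothetical): flatten each crop and apply a linear layer
model = nn.Sequential(nn.Flatten(), nn.Linear(c * h * w, num_classes))

with torch.no_grad():
    result = model(inputs.view(-1, c, h, w))          # (bs * ncrops, num_classes)
    result_avg = result.view(bs, ncrops, -1).mean(1)  # (bs, num_classes)

pred = result_avg.argmax(dim=1)  # one prediction per original image, aligned with target
print(result.shape, result_avg.shape, pred.shape)
print((pred == target).float().mean().item())  # accuracy on this fake batch
```

Because view(-1, c, h, w) flattens the batch dimension first, rows 0 to 9 of result belong to the first image, rows 10 to 19 to the second, and so on; view(bs, ncrops, -1) undoes exactly that grouping before averaging.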