赞
踩
[九月]Deepfake-FFDI-plot_transforms_illustrations: https://www.kaggle.com/code/chg0901/deepfake-ffdi-plot-transforms-illustrations
让我们逐部分代码进行详细解释:
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
import matplotlib.pyplot as plt
:导入Matplotlib库的pyplot模块,用于绘制图形。import torch
:导入PyTorch库,主要用于深度学习和张量操作。from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
:导入两个工具函数,用于在图像上绘制边界框和分割掩码。from torchvision import tv_tensors
:导入TorchVision中的tv_tensors模块,用于处理图像数据。from torchvision.transforms.v2 import functional as F
:导入TorchVision v2版本中的functional模块并重命名为F,用于图像变换的功能函数。def plot(imgs, row_title=None, **imshow_kwargs): if not isinstance(imgs[0], list): # Make a 2d grid even if there's just 1 row imgs = [imgs] num_rows = len(imgs) num_cols = len(imgs[0]) _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False) for row_idx, row in enumerate(imgs): for col_idx, img in enumerate(row): boxes = None masks = None if isinstance(img, tuple): img, target = img if isinstance(target, dict): boxes = target.get("boxes") masks = target.get("masks") elif isinstance(target, tv_tensors.BoundingBoxes): boxes = target else: raise ValueError(f"Unexpected target type: {type(target)}") img = F.to_image(img) if img.dtype.is_floating_point and img.min() < 0: # Poor man's re-normalization for the colors to be OK-ish. This # is useful for images coming out of Normalize() img -= img.min() img /= img.max() img = F.to_dtype(img, torch.uint8, scale=True) if boxes is not None: img = draw_bounding_boxes(img, boxes, colors="yellow", width=3) if masks is not None: img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65) ax = axs[row_idx, col_idx] ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs) ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) if row_title is not None: for row_idx in range(num_rows): axs[row_idx, 0].set(ylabel=row_title[row_idx]) plt.tight_layout()
这个函数用于绘制图像,并在图像上绘制边界框和分割掩码。
检查输入图像格式:
if not isinstance(imgs[0], list):
imgs = [imgs]
确保输入是二维网格,即使只有一行图像。
创建子图:
num_rows = len(imgs)
num_cols = len(imgs[0])
_, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
根据图像的行数和列数创建相应数量的子图。
处理每个图像:
for row_idx, row in enumerate(imgs):
for col_idx, img in enumerate(row):
...
遍历每个图像,进行处理。
提取目标信息:
if isinstance(img, tuple):
img, target = img
...
检查图像是否包含目标信息(如边界框和分割掩码)。
转换图像格式:
img = F.to_image(img)
if img.dtype.is_floating_point and img.min() < 0:
img -= img.min()
img /= img.max()
img = F.to_dtype(img, torch.uint8, scale=True)
将图像转换为适合显示的格式。
绘制边界框和分割掩码:
if boxes is not None:
img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
if masks is not None:
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
显示图像:
ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
设置行标题(如果有):
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
调整布局:
plt.tight_layout()
!wget https://mirror.coggle.club/image/tyler-swift.jpg
orig_img = Image.open('/kaggle/working/tyler-swift.jpg')
plt.axis('off')
plt.imshow(orig_img)
plt.show()
!wget ...
:下载图像。orig_img = Image.open(...)
:打开下载的图像。plt.axis('off')
:关闭坐标轴显示。plt.imshow(orig_img)
:显示图像。plt.show()
:展示图像。padded_imgs = [v2.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot([orig_img] + padded_imgs)
v2.Pad(padding=padding)
:使用不同的填充大小对图像进行填充。plot([orig_img] + padded_imgs)
:显示原始图像和填充后的图像。其他变换如 Resize
, CenterCrop
, FiveCrop
, RandomPerspective
, RandomRotation
等类似,分别使用不同的变换方法处理图像并展示结果。
gray_img = v2.Grayscale()(orig_img)
plot([orig_img, gray_img], cmap='gray')
v2.Grayscale()
:将图像转换为灰度图。plot([orig_img, gray_img], cmap='gray')
:显示原始图像和灰度图。其他变换如 ColorJitter
, GaussianBlur
, RandomInvert
等类似,分别使用不同的光度变换方法处理图像并展示结果。
policies = [v2.AutoAugmentPolicy.CIFAR10, v2.AutoAugmentPolicy.IMAGENET, v2.AutoAugmentPolicy.SVHN]
augmenters = [v2.AutoAugment(policy) for policy in policies]
imgs = [
[augmenter(orig_img) for _ in range(4)]
for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot([[orig_img] + row for row in imgs], row_title=row_title)
v2.AutoAugmentPolicy
:使用不同的自动增强策略。v2.AutoAugment(policy)
:基于策略自动增强图像。plot(...)
:展示增强后的图像。其他增强变换如 RandAugment
, TrivialAugmentWide
, AugMix
等类似,分别使用不同的增强方法处理图像并展示结果。
hflipper = v2.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot([orig_img] + transformed_imgs)
v2.RandomHorizontalFlip(p=0.5)
:以50%的概率水平翻转图像。plot([orig_img] + transformed_imgs)
:显示原始图像和翻转后的图像。其他随机变换如 RandomVerticalFlip
, RandomApply
等类似,分别使用不同的随机方法处理图像并展示结果。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。