当前位置:   article > 正文

【Datawhale 】Datawhale AI 夏令营-Task3笔记

【Datawhale 】Datawhale AI 夏令营-Task3笔记

学习- 九月助教老师的图像数据增强方法实操代码

[九月]Deepfake-FFDI-plot_transforms_illustrations: https://www.kaggle.com/code/chg0901/deepfake-ffdi-plot-transforms-illustrations
让我们逐部分代码进行详细解释:

导入库

import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
  • 1
  • 2
  • 3
  • 4
  • 5
  • import matplotlib.pyplot as plt:导入Matplotlib库的pyplot模块,用于绘制图形。
  • import torch:导入PyTorch库,主要用于深度学习和张量操作。
  • from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks:导入两个工具函数,用于在图像上绘制边界框和分割掩码。
  • from torchvision import tv_tensors:导入TorchVision中的tv_tensors模块,用于处理图像数据。
  • from torchvision.transforms.v2 import functional as F:导入TorchVision v2版本中的functional模块并重命名为F,用于图像变换的功能函数。

定义绘图函数

def plot(imgs, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            boxes = None
            masks = None
            if isinstance(img, tuple):
                img, target = img
                if isinstance(target, dict):
                    boxes = target.get("boxes")
                    masks = target.get("masks")
                elif isinstance(target, tv_tensors.BoundingBoxes):
                    boxes = target
                else:
                    raise ValueError(f"Unexpected target type: {type(target)}")
            img = F.to_image(img)
            if img.dtype.is_floating_point and img.min() < 0:
                # Poor man's re-normalization for the colors to be OK-ish. This
                # is useful for images coming out of Normalize()
                img -= img.min()
                img /= img.max()

            img = F.to_dtype(img, torch.uint8, scale=True)
            if boxes is not None:
                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
            if masks is not None:
                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

            ax = axs[row_idx, col_idx]
            ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43

这个函数用于绘制图像,并在图像上绘制边界框和分割掩码。

  1. 检查输入图像格式

    if not isinstance(imgs[0], list):
        imgs = [imgs]
    
    • 1
    • 2

    确保输入是二维网格,即使只有一行图像。

  2. 创建子图

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    
    • 1
    • 2
    • 3

    根据图像的行数和列数创建相应数量的子图。

  3. 处理每个图像

    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            ...
    
    • 1
    • 2
    • 3

    遍历每个图像,进行处理。

  4. 提取目标信息

    if isinstance(img, tuple):
        img, target = img
        ...
    
    • 1
    • 2
    • 3

    检查图像是否包含目标信息(如边界框和分割掩码)。

  5. 转换图像格式

    img = F.to_image(img)
    if img.dtype.is_floating_point and img.min() < 0:
        img -= img.min()
        img /= img.max()
    img = F.to_dtype(img, torch.uint8, scale=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    将图像转换为适合显示的格式。

  6. 绘制边界框和分割掩码

    if boxes is not None:
        img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
    if masks is not None:
        img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
    
    • 1
    • 2
    • 3
    • 4
  7. 显示图像

    ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    
    • 1
    • 2
  8. 设置行标题(如果有):

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])
    
    • 1
    • 2
    • 3
  9. 调整布局

    plt.tight_layout()
    
    • 1

下载并显示图像

!wget https://mirror.coggle.club/image/tyler-swift.jpg

orig_img = Image.open('/kaggle/working/tyler-swift.jpg')
plt.axis('off')
plt.imshow(orig_img)
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • !wget ...:下载图像。
  • orig_img = Image.open(...):打开下载的图像。
  • plt.axis('off'):关闭坐标轴显示。
  • plt.imshow(orig_img):显示图像。
  • plt.show():展示图像。

几何变换示例

padded_imgs = [v2.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot([orig_img] + padded_imgs)
  • 1
  • 2
  • v2.Pad(padding=padding):使用不同的填充大小对图像进行填充。
  • plot([orig_img] + padded_imgs):显示原始图像和填充后的图像。

其他变换如 Resize, CenterCrop, FiveCrop, RandomPerspective, RandomRotation 等类似,分别使用不同的变换方法处理图像并展示结果。

光度变换示例

gray_img = v2.Grayscale()(orig_img)
plot([orig_img, gray_img], cmap='gray')
  • 1
  • 2
  • v2.Grayscale():将图像转换为灰度图。
  • plot([orig_img, gray_img], cmap='gray'):显示原始图像和灰度图。

其他变换如 ColorJitter, GaussianBlur, RandomInvert 等类似,分别使用不同的光度变换方法处理图像并展示结果。

增强变换示例

policies = [v2.AutoAugmentPolicy.CIFAR10, v2.AutoAugmentPolicy.IMAGENET, v2.AutoAugmentPolicy.SVHN]
augmenters = [v2.AutoAugment(policy) for policy in policies]
imgs = [
    [augmenter(orig_img) for _ in range(4)]
    for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot([[orig_img] + row for row in imgs], row_title=row_title)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • v2.AutoAugmentPolicy:使用不同的自动增强策略。
  • v2.AutoAugment(policy):基于策略自动增强图像。
  • plot(...):展示增强后的图像。

其他增强变换如 RandAugment, TrivialAugmentWide, AugMix 等类似,分别使用不同的增强方法处理图像并展示结果。

随机应用变换示例

hflipper = v2.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot([orig_img] + transformed_imgs)
  • 1
  • 2
  • 3
  • v2.RandomHorizontalFlip(p=0.5):以50%的概率水平翻转图像。
  • plot([orig_img] + transformed_imgs):显示原始图像和翻转后的图像。

其他随机变换如 RandomVerticalFlip, RandomApply 等类似,分别使用不同的随机方法处理图像并展示结果。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/秋刀鱼在做梦/article/detail/963665
推荐阅读
相关标签
  

闽ICP备14008679号