
[Deep Learning] A Summary of 13 PyTorch Image Augmentation Methods


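Reposted from | 极市平台

Author | 结发授长生

Source | https://zhuanlan.zhihu.com/p/559887437

Data augmentation increases the diversity of the images in a dataset, which can improve a model's performance and its ability to generalize. The main image augmentation techniques include: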

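  • Resize

  • Grayscale

  • Normalization

  • Random rotation

  • Center crop

  • Random crop

  • Gaussian blur

  • Brightness, contrast and saturation adjustment

  • Horizontal flip

  • Vertical flip

  • Gaussian noise

  • Random blocks

  • Central region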

1

『Resize』

Before resizing, we need to import the libraries and load the data (a fundus image is used as the example).

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/000001.tif'))
    torch.manual_seed(0)  # seed the CPU random number generator so results can be reproduced
    print(np.asarray(orig_img).shape)  # (800, 800, 3)

    # Resize the image. An int size matches the shorter edge and keeps the
    # aspect ratio, so this square input becomes 128*128 and 256*256.
    resized_imgs = [T.Resize(size=size)(orig_img) for size in [128, 256]]
    ax1 = plt.subplot(131)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(132)
    ax2.set_title('resize:128*128')
    ax2.imshow(resized_imgs[0])
    ax3 = plt.subplot(133)
    ax3.set_title('resize:256*256')
    ax3.imshow(resized_imgs[1])
    plt.show()

[Figure: original | resize:128*128 | resize:256*256]

2

『Grayscale』

This operation converts an RGB image into a grayscale image.

    # Reuses orig_img and the imports from the Resize section.
    gray_img = T.Grayscale()(orig_img)
    ax1 = plt.subplot(121)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(122)
    ax2.set_title('gray')
    ax2.imshow(gray_img, cmap='gray')  # a colormap is needed to show one channel as gray
    plt.show()

[Figure: original | gray]

3

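『Normalization』

Normalization puts every input channel on a comparable scale, which can speed up the training of neural-network models. For each input channel it:

  • subtracts the channel mean

  • divides the result by the channel standard deviation.

    # Reuses orig_img and the imports from the Resize section.
    # mean = std = 0.5 maps the [0, 1] range produced by ToTensor to [-1, 1].
    normalized_img = T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(T.ToTensor()(orig_img))
    # Converted back to an image for display only: negative values do not
    # survive the 8-bit conversion, so the picture is a rough visualization.
    normalized_img = [T.ToPILImage()(normalized_img)]
    ax1 = plt.subplot(121)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(122)
    ax2.set_title('normalize')
    ax2.imshow(normalized_img[0])
    plt.show()

[Figure: original | normalize]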

4

『Random rotation』

Rotates the image by an angle drawn at random from a specified range.

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))
    # degrees=90 samples the angle uniformly from [-90, 90];
    # use degrees=(90, 90) for a fixed 90° rotation.
    rotated_imgs = [T.RandomRotation(degrees=90)(orig_img)]
    plt.figure('random rotation')
    ax1 = plt.subplot(121)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(122)
    ax2.set_title('rotated (within ±90°)')
    ax2.imshow(np.array(rotated_imgs[0]))
    plt.show()

[Figure: original | rotated]

5

『Center crop』

Crops the central region of the image.

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))
    # One crop per target size, each taken around the image center.
    center_crops = [T.CenterCrop(size=size)(orig_img) for size in (128, 64)]
    plt.figure('center crop')
    ax1 = plt.subplot(131)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(132)
    ax2.set_title('128*128')
    ax2.imshow(np.array(center_crops[0]))
    ax3 = plt.subplot(133)
    ax3.set_title('64*64')
    ax3.imshow(np.array(center_crops[1]))
    plt.show()

[Figure: original | 128*128 | 64*64]
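
What a center crop selects can be written out with plain array slicing. This is an illustrative NumPy sketch, not torchvision's implementation: the helper `center_crop` is hypothetical and assumes the crop is no larger than the image.

```python
import numpy as np

def center_crop(arr, size):
    """Return the central size*size window of an H*W(*C) array."""
    h, w = arr.shape[:2]
    top = (h - size) // 2   # rows skipped above the window
    left = (w - size) // 2  # columns skipped left of the window
    return arr[top:top + size, left:left + size]

img = np.arange(6 * 6).reshape(6, 6)  # toy 6*6 "image" with values 0..35
crop = center_crop(img, 2)
print(crop)  # [[14 15]
             #  [20 21]]
```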

6

『Random crop』

Crops a randomly chosen part of the image.

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))
    # The crop position is drawn at random; the image must be at least as large as the crop.
    random_crops = [T.RandomCrop(size=size)(orig_img) for size in (400, 300)]
    plt.figure('random crop')
    ax1 = plt.subplot(131)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(132)
    ax2.set_title('400*400')
    ax2.imshow(np.array(random_crops[0]))
    ax3 = plt.subplot(133)
    ax3.set_title('300*300')
    ax3.imshow(np.array(random_crops[1]))
    plt.show()

[Figure: original | 400*400 | 300*300]

7

『Gaussian blur』

Blurs the image by convolving it with a Gaussian kernel.

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))
    # Same 3*3 kernel, two different standard deviations.
    blurred_imgs = [T.GaussianBlur(kernel_size=(3, 3), sigma=sigma)(orig_img) for sigma in (3, 7)]
    plt.figure('gaussian blur')
    ax1 = plt.subplot(131)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(132)
    ax2.set_title('sigma=3')
    ax2.imshow(np.array(blurred_imgs[0]))
    ax3 = plt.subplot(133)
    ax3.set_title('sigma=7')
    ax3.imshow(np.array(blurred_imgs[1]))
    plt.show()

[Figure: original | sigma=3 | sigma=7]
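
With a 3*3 kernel, sigma=3 and sigma=7 give almost the same weights, so the two blurred images differ very little. The sketch below works this out with a sampled 1D Gaussian normalized to sum to 1, which is my understanding of how torchvision builds its kernel; the helper `gaussian_kernel_1d` is hypothetical. Sizing the kernel at roughly 6*sigma + 1 taps is a common rule of thumb, not something the original article states.

```python
import math

def gaussian_kernel_1d(kernel_size, sigma):
    """Sample a Gaussian at integer offsets around 0 and normalize to sum 1."""
    half = (kernel_size - 1) / 2
    offsets = [i - half for i in range(kernel_size)]
    weights = [math.exp(-0.5 * (x / sigma) ** 2) for x in offsets]
    total = sum(weights)
    return [w / total for w in weights]

k3 = gaussian_kernel_1d(3, 3)     # roughly [0.327, 0.346, 0.327]
k7 = gaussian_kernel_1d(3, 7)     # roughly [0.332, 0.336, 0.332]
wide = gaussian_kernel_1d(19, 3)  # 19 taps cover about 3 sigma on each side

print([round(w, 3) for w in k3])
print([round(w, 3) for w in k7])
```

Both 3-tap kernels are close to a uniform average, which is why a larger `kernel_size` is needed before a larger sigma becomes visible.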

8

『Brightness, contrast and saturation adjustment』

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))
    # A (min, max) pair with equal ends pins the factor: brightness is
    # doubled, contrast and saturation are halved.
    colorjitter_img = [T.ColorJitter(brightness=(2, 2), contrast=(0.5, 0.5), saturation=(0.5, 0.5))(orig_img)]
    plt.figure('color jitter')
    ax1 = plt.subplot(121)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(122)
    ax2.set_title('colorjitter_img')
    ax2.imshow(np.array(colorjitter_img[0]))
    plt.show()

[Figure: original | colorjitter_img]

9

『Horizontal flip』

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))
    # p=1 always flips; use p=0.5 to flip about half of the samples during training.
    HorizontalFlip_img = [T.RandomHorizontalFlip(p=1)(orig_img)]
    plt.figure('horizontal flip')
    ax1 = plt.subplot(121)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(122)
    ax2.set_title('HorizontalFlip')
    ax2.imshow(np.array(HorizontalFlip_img[0]))
    plt.show()

[Figure: original | HorizontalFlip]

10

『Vertical flip』

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))
    # p=1 always flips the image top to bottom.
    VerticalFlip_img = [T.RandomVerticalFlip(p=1)(orig_img)]
    plt.figure('vertical flip')
    ax1 = plt.subplot(121)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(122)
    ax2.set_title('VerticalFlip')
    ax2.imshow(np.array(VerticalFlip_img[0]))
    plt.show()

[Figure: original | VerticalFlip]

11

『Gaussian noise』

Adds Gaussian noise to the image. The amount is controlled by a noise factor: the higher the factor, the noisier the image.

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import torchvision.transforms as T
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))

    def add_noise(inputs, noise_factor=0.3):
        # zero-mean, unit-variance noise scaled by noise_factor
        noisy = inputs + torch.randn_like(inputs) * noise_factor
        # keep the result in the valid [0, 1] range of a ToTensor() image
        noisy = torch.clip(noisy, 0., 1.)
        return noisy

    noise_imgs = [add_noise(T.ToTensor()(orig_img), noise_factor) for noise_factor in (0.3, 0.6)]
    noise_imgs = [T.ToPILImage()(noise_img) for noise_img in noise_imgs]
    plt.figure('gaussian noise')
    ax1 = plt.subplot(131)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(132)
    ax2.set_title('noise_factor=0.3')
    ax2.imshow(np.array(noise_imgs[0]))
    ax3 = plt.subplot(133)
    ax3.set_title('noise_factor=0.6')
    ax3.imshow(np.array(noise_imgs[1]))
    plt.show()

[Figure: original | noise_factor=0.3 | noise_factor=0.6]
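
The noise factor is the standard deviation of the added noise on the [0, 1] scale. The sketch below checks this on a flat gray tensor; the `generator` argument is an addition for reproducibility and is not part of the article's function.

```python
import torch

def add_noise(inputs, noise_factor=0.3, generator=None):
    """Add zero-mean Gaussian noise with std noise_factor, then clamp to [0, 1]."""
    noise = torch.randn(inputs.shape, generator=generator)
    return torch.clamp(inputs + noise * noise_factor, 0.0, 1.0)

gen = torch.Generator().manual_seed(0)    # private RNG so the result is repeatable
flat_gray = torch.full((3, 64, 64), 0.5)  # constant image, so all variation is noise

noisy = add_noise(flat_gray, noise_factor=0.1, generator=gen)
print(noisy.std().item())  # close to 0.1
```

With larger factors, more values hit the 0 and 1 limits and are clamped, so the measured spread falls below the nominal factor.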

12

『Random blocks』

Square patches are placed at random positions in the image. The more patches there are, the harder the task becomes for the neural network.

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))

    def add_random_boxes(img, n_k, size=64):
        """Black out n_k randomly placed size*size squares."""
        h, w = size, size
        img = np.asarray(img.convert('RGB')).copy()
        img_h, img_w = img.shape[:2]
        for _ in range(n_k):
            # sample the top-left corner so the box stays inside the image
            y = np.random.randint(0, img_h - h)
            x = np.random.randint(0, img_w - w)
            img[y:y + h, x:x + w] = 0
        return Image.fromarray(img)

    blocks_imgs = [add_random_boxes(orig_img, n_k=10)]
    plt.figure('random blocks')
    ax1 = plt.subplot(121)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(122)
    ax2.set_title('10 black boxes')
    ax2.imshow(np.array(blocks_imgs[0]))
    plt.show()

[Figure: original | 10 black boxes]

13

『Central region』

Similar to random blocks, except that the patch is placed at the center of the image.

    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image

    plt.rcParams["savefig.bbox"] = 'tight'
    orig_img = Image.open(Path('image/2.png'))

    def add_central_region(img, size=32):
        """Black out a square around the image center.

        The square extends `size` pixels to each side of the center,
        so its side length is 2 * size.
        """
        img = np.asarray(img.convert('RGB')).copy()
        img_h, img_w = img.shape[:2]
        cy, cx = img_h // 2, img_w // 2
        img[cy - size:cy + size, cx - size:cx + size] = 0
        return Image.fromarray(img)

    central_imgs = [add_central_region(orig_img, size=128)]
    plt.figure('central region')
    ax1 = plt.subplot(121)
    ax1.set_title('original')
    ax1.imshow(orig_img)
    ax2 = plt.subplot(122)
    ax2.set_title('central region')
    ax2.imshow(np.array(central_imgs[0]))
    plt.show()

[Figure: original | central region]
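
In the function above `size` is a half-width, so `size=128` blanks a 256*256 square. If the full side length is the more convenient parameter, the patch can be written as in this NumPy sketch; `blank_center` and its `side` argument are hypothetical names, and the white test image is synthetic.

```python
import numpy as np

def blank_center(arr, side):
    """Zero a centered side*side square in an H*W(*C) array."""
    out = arr.copy()
    h, w = out.shape[:2]
    top, left = (h - side) // 2, (w - side) // 2
    out[top:top + side, left:left + side] = 0
    return out

img = np.full((256, 256, 3), 255, dtype=np.uint8)  # synthetic white image
masked = blank_center(img, side=128)

zero_pixels = int((masked[..., 0] == 0).sum())
print(zero_pixels)  # 16384, i.e. 128 * 128
```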
 
 

