当前位置:   article > 正文

pytorch自动编码器实现有损图像压缩_autoencoder图片压缩

autoencoder图片压缩

自动编码器(AutoEncoder)由编码器(Encoder)和解码器(Decoder)两部分组成。编码器和解码器可以是任意模型,通常神经网络模型作为编码器和解码器。

自动编码器作为一种数据压缩的方法,其原理是:输入数据经过编码器变成一个编码(code),然后将这个编码作为解码器的输入,观察解码器的输出是否能还原原始数据,因此将解码器的输出和原始数据的误差作为最优化的目标。

下面以MNIST数据集为例,使用pytorch1.0构建一个卷积神经网络做自动编码器。


1.添加引用的库文件

  1. import os
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torch.utils.data import DataLoader
  6. from torchvision import datasets, transforms
  7. from torchvision.utils import save_image

2.定义超参数,是否使用GPU加速

  1. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  2. batch_size = 512

3.加载MNIST数据集,并将图片的大小变为-1~1之间,这样可以使输入变得更对称,训练更加容易收敛。

  1. # 标准化
  2. data_tf = transforms.Compose(
  3. [transforms.ToTensor(),
  4. transforms.Normalize([0.5], [0.5])]
  5. )
  6. train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
  7. train_data = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

4.定义卷积神经网络的自动编码器

  1. class AutoEncoder(nn.Module):
  2. def __init__(self):
  3. super(AutoEncoder, self).__init__()
  4. self.encoder = nn.Sequential(
  5. nn.Conv2d(1, 16, 3, stride=3, padding=1), # b,16,10,10
  6. nn.ReLU(True),
  7. nn.MaxPool2d(2, stride=2), # b,16,5,5
  8. nn.Conv2d(16, 8, 3, stride=2, padding=1), # b,8,3,3
  9. nn.ReLU(True),
  10. nn.MaxPool2d(2, stride=1) # b,8,2,2
  11. )
  12. self.decoder = nn.Sequential(
  13. nn.ConvTranspose2d(8, 16, 3, stride=2), # b,16,5,5
  14. nn.ReLU(True),
  15. nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1), # b,8,15,15
  16. nn.ReLU(True),
  17. nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1), # b,1,28,28
  18. nn.Tanh()
  19. )
  20. def forward(self, x):
  21. encode = self.encoder(x)
  22. decode = self.decoder(encode)
  23. return encode, decode

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0,groups=1, bias=True, dilation=1)

  • in_channels(int):输入数据的通道数;
  • out_channels(int):输出数据的通道数;
  • kernel_size(int or tuple):滤波器或卷积核的大小;
  • stride(int or tuple,optional) :步长;
  • padding(int or tuple, optional):四周是否进行0填充;
  • groups(int, optional) – 从输入通道到输出通道的阻塞连接数
  • bias(bool, optional) - 如果bias=True,添加偏置
  • dilation(int or tuple, optional) – 卷积核元素之间的间距

对于每一条边输入,输出的尺寸的公式如下:

output = ⌊(input − kernel_size + 2 × padding) / stride⌋ + 1

 

解码器使用nn.ConvTranspose2d(),可以看作卷积的反操作。具体参数如下:

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0,output_padding=0, groups=1, bias=True, dilation=1)

  • in_channels(int) – 输入信号的通道数
  • out_channels(int) – 卷积产生的通道数
  • kerner_size(int or tuple) - 卷积核的大小
  • stride(int or tuple,optional) - 卷积步长,即要将输入扩大的倍数。
  • padding(int or tuple, optional) - 输入的每一条边补充0的层数,高宽都增加2*padding
  • output_padding(int or tuple, optional) - 输出每一条边额外补充0的层数,高宽都增加output_padding(只加在输出的一侧)
  • groups(int, optional) – 从输入通道到输出通道的阻塞连接数
  • bias(bool, optional) - 如果bias=True,添加偏置
  • dilation(int or tuple, optional) – 卷积核元素之间的间距

对于每一条边输入,输出的尺寸的公式如下:

output = (input − 1) × stride − 2 × padding + kernel_size + output_padding

5.实例化模型,定义loss函数和优化函数

  1. model = AutoEncoder().to(device)
  2. # 定义loss函数和优化方法
  3. loss_fn = nn.MSELoss()
  4. optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

6.训练并保存解码器生成的图片

  1. for t in range(40):
  2. for data in train_data:
  3. img, label = data
  4. img = img.to(device)
  5. label = label.to(device)
  6. _, output = model(img)
  7. loss = loss_fn(output, img) / img.shape[0] # 平均损失
  8. # 反向传播
  9. optimizer.zero_grad()
  10. loss.backward()
  11. optimizer.step()
  12. if (t + 1) % 5 == 0: # 每 5 次,保存一下解码的图片和原图片
  13. print('epoch: {}, Loss: {:.4f}'.format(t + 1, loss.item()))
  14. pic = to_img(output.cpu().data)
  15. if not os.path.exists('./conv_autoencoder'):
  16. os.mkdir('./conv_autoencoder')
  17. save_image(pic, './conv_autoencoder/decode_image_{}.png'.format(t + 1))
  18. save_image(img, './conv_autoencoder/raw_image_{}.png'.format(t + 1))

结果对比(左边生成图片,右边原始图片):

附上完整代码:

  1. import os
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torch.utils.data import DataLoader
  6. from torchvision import datasets, transforms
  7. from torchvision.utils import save_image
  8. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  9. batch_size = 512
  10. # 标准化
  11. data_tf = transforms.Compose(
  12. [transforms.ToTensor(),
  13. transforms.Normalize([0.5], [0.5])]
  14. )
  15. train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
  16. train_data = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  17. def to_img(x):
  18. x = 0.5 * (x + 1.) # 将-1~1转成0-1
  19. x = x.clamp(0, 1)
  20. x = x.view(x.shape[0], 1, 28, 28)
  21. return x
  22. class AutoEncoder(nn.Module):
  23. def __init__(self):
  24. super(AutoEncoder, self).__init__()
  25. self.encoder = nn.Sequential(
  26. nn.Conv2d(1, 16, 3, stride=3, padding=1), # b,16,10,10
  27. nn.ReLU(True),
  28. nn.MaxPool2d(2, stride=2), # b,16,5,5
  29. nn.Conv2d(16, 8, 3, stride=2, padding=1), # b,8,3,3
  30. nn.ReLU(True),
  31. nn.MaxPool2d(2, stride=1) # b,8,2,2
  32. )
  33. self.decoder = nn.Sequential(
  34. # nn.ConvTranspose2d(8, 8, 3, stride=2, padding=1), # b,8,3,3
  35. # nn.ReLU(True),
  36. # nn.ConvTranspose2d(8, 16, 4, stride=4, padding=1), # b,16,10,10
  37. # nn.ReLU(True),
  38. # nn.ConvTranspose2d(16, 1, 3, stride=3, padding=1), # b,1,28,28
  39. # nn.Tanh()
  40. nn.ConvTranspose2d(8, 16, 3, stride=2), # b,16,5,5
  41. nn.ReLU(True),
  42. nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1), # b,8,15,15
  43. nn.ReLU(True),
  44. nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1), # b,1,28,28
  45. nn.Tanh()
  46. )
  47. def forward(self, x):
  48. encode = self.encoder(x)
  49. decode = self.decoder(encode)
  50. return encode, decode
  51. model = AutoEncoder().to(device)
  52. # 定义loss函数和优化方法
  53. loss_fn = nn.MSELoss()
  54. optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
  55. for t in range(40):
  56. for data in train_data:
  57. img, label = data
  58. img = img.to(device)
  59. label = label.to(device)
  60. _, output = model(img)
  61. loss = loss_fn(output, img) / img.shape[0] # 平均损失
  62. # 反向传播
  63. optimizer.zero_grad()
  64. loss.backward()
  65. optimizer.step()
  66. if (t + 1) % 5 == 0: # 每 5 次,保存一下解码的图片和原图片
  67. print('epoch: {}, Loss: {:.4f}'.format(t + 1, loss.item()))
  68. pic = to_img(output.cpu().data)
  69. if not os.path.exists('./conv_autoencoder'):
  70. os.mkdir('./conv_autoencoder')
  71. save_image(pic, './conv_autoencoder/decode_image_{}.png'.format(t + 1))
  72. save_image(img, './conv_autoencoder/raw_image_{}.png'.format(t + 1))

 

 

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/852356
推荐阅读
相关标签
  

闽ICP备14008679号