赞
踩
- import matplotlib.pyplot as plt
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import torchvision
-
-
- # 在一个类中编写编码器和解码器层。为编码器和解码器层的组件都定义了全连接层
- class AE(nn.Module):
- def __init__(self, **kwargs):
- super().__init__()
- self.encoder_hidden_layer = nn.Linear(
- in_features=kwargs["input_shape"], out_features=128
- ) # 编码器隐藏层
- self.encoder_output_layer = nn.Linear(
- in_features=128, out_features=128
- ) # 编码器输出层
- self.decoder_hidden_layer = nn.Linear(
- in_features=128, out_features=128
- ) # 解码器隐藏层
- self.decoder_output_layer = nn.Linear(
- in_features=128, out_features=kwargs["input_shape"]
- ) # 解码器输出层
-
- # 定义了模型的前向传播过程,包括激活函数的应用和重构图像的生成
- def forward(self, features):
- activation = self.encoder_hidden_layer(features)
- activation = torch.relu(activation) # ReLU 激活函数,得到编码器的激活值
- code = self.encoder_output_layer(activation)
- code = torch.sigmoid(code) # Sigmoid 激活函数,以确保编码后的表示在 [0, 1] 范围内
- activation = self.decoder_hidden_layer(code)
- activation = torch.relu(activation)
- activation = self.decoder_output_layer(activation)
- reconstructed = torch.sigmoid(activation)
- return reconstructed
-
-
- if __name__ == '__main__':
- # 设置批大小、学习周期和学习率
- batch_size = 512
- epochs = 30
- learning_rate = 1e-3
-
- # 载入 MNIST 数据集中的图片进行训练
- transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()]) # 将图像转换为张量
-
- train_dataset = torchvision.datasets.MNIST(
- root="~/torch_datasets", train=True, transform=transform, download=True
- ) # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 True
-
- train_loader = torch.utils.data.DataLoader(
- train_dataset, batch_size=batch_size, shuffle=True
- ) # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据
-
- # 在使用定义的 AE 类之前,有以下事情要做:
- # 配置要在哪个设备上运行
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # 建立 AE 模型并载入到 CPU 设备
- model = AE(input_shape=784).to(device)
-
- # Adam 优化器,学习率 10e-3
- optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-
- # 使用均方误差(MSE)损失函数
- criterion = nn.MSELoss()
-
- # 在GPU设备上运行,实例化一个输入大小为784的AE自编码器,并用Adam作为训练优化器用MSELoss作为损失函数
- # 训练:
- for epoch in range(epochs):
- loss = 0
- for batch_features, _ in train_loader:
- # 将小批数据变形为 [N, 784] 矩阵,并加载到 CPU 设备
- batch_features = batch_features.view(-1, 784).to(device)
-
- # 梯度设置为 0,因为 torch 会累加梯度
- optimizer.zero_grad()
-
- # 计算重构
- outputs = model(batch_features)
-
- # 计算训练重建损失
- train_loss = criterion(outputs, batch_features)
-
- # 计算累积梯度
- train_loss.backward()
-
- # 根据当前梯度更新参数
- optimizer.step()
-
- # 将小批量训练损失加到周期损失中
- loss += train_loss.item()
-
- # 计算每个周期的训练损失
- loss = loss / len(train_loader)
-
- # 显示每个周期的训练损失
- print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))
-
- # 用训练过的自编码器提取一些测试用例来重构
- test_dataset = torchvision.datasets.MNIST(
- root="~/torch_datasets", train=False, transform=transform, download=True
- ) # 加载 MNIST 测试数据集
-
- test_loader = torch.utils.data.DataLoader(
- test_dataset, batch_size=10, shuffle=False
- ) # 创建一个测试数据加载器
-
- test_examples = None
-
- # 通过循环遍历测试数据加载器,获取一个批次的图像数据
- with torch.no_grad(): # 使用 torch.no_grad() 上下文管理器,确保在该上下文中不会进行梯度计算
- for batch_features in test_loader: # 历测试数据加载器中的每个批次的图像数据
- batch_features = batch_features[0] # 获取当前批次的图像数据
- test_examples = batch_features.view(-1, 784).to(
- device) # 将当前批次的图像数据转换为大小为 (批大小, 784) 的张量,并加载到指定的设备(CPU 或 GPU)上
- reconstruction = model(test_examples) # 使用训练好的自编码器模型对测试数据进行重构,即生成重构的图像
- break
-
- # 试着用训练过的自编码器重建一些测试图像
- with torch.no_grad():
- number = 10 # 设置要显示的图像数量
- plt.figure(figsize=(20, 4)) # 创建一个新的 Matplotlib 图形,设置图形大小为 (20, 4)
- for index in range(number): # 遍历要显示的图像数量
- # 显示原始图
- ax = plt.subplot(2, number, index + 1)
- plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))
- plt.gray()
- ax.get_xaxis().set_visible(False)
- ax.get_yaxis().set_visible(False)
-
- # 显示重构图
- ax = plt.subplot(2, number, index + 1 + number)
- plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))
- plt.gray()
- ax.get_xaxis().set_visible(False)
- ax.get_yaxis().set_visible(False)
- plt.savefig('reconstruction_results.png') # 保存图像
- plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。