One reason samples generated by a VAE tend to be of low quality may be that the VAE encodes images into continuous latent variables (mapped to a standard normal distribution); encoding images into discrete latent variables may work better, since in everyday life we habitually describe things with discrete attributes (a person is tall or short, fat or thin, and so on).
VQ-VAE has three key modules: the Encoder, the Decoder, and the Codebook.
The Encoder maps the input to feature vectors; each feature vector is compared (by L2 distance) against the embedding vectors in the Codebook, replaced by its nearest embedding, and the result is fed to the Decoder to reconstruct the input.
The VQ-VAE loss consists of the reconstruction loss between the source image and the reconstructed image, plus the quantization loss (vq_loss) from the Codebook quantization step; the full objective is written out below.
For a detailed introduction to VQ-VAE, see: 轻松理解 VQ-VAE.
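For reference, the training objective from the original paper (van den Oord et al., 2017), written as a loss to minimize, where sg[·] is the stop-gradient operator, z_e(x) the encoder output, e the selected codebook embedding, and β the commitment weight (commitment_cost in the code below):

```latex
L = -\log p\bigl(x \mid z_q(x)\bigr)
    + \bigl\lVert \operatorname{sg}[z_e(x)] - e \bigr\rVert_2^2
    + \beta \bigl\lVert z_e(x) - \operatorname{sg}[e] \bigr\rVert_2^2
```

The second term pulls the codebook toward the encoder outputs, and the third (commitment) term keeps the encoder outputs close to the codebook. The implementation below contains both the vanilla VectorQuantizer and an EMA variant that drops the second term in favor of exponential-moving-average codebook updates.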
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # third term of the loss in the paper
        q_latent_loss = F.mse_loss(quantized, inputs.detach())  # second term of the loss in the paper
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()  # straight-through gradient copy
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape  # e.g. B(256) H(8) W(8) C(64)

        # Flatten input: BHWC -> (BHW, C)
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances to every embedding in the codebook
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding: pick the nearest embedding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)  # map indices to one-hot vectors

        # Quantize and unflatten: look up the selected codebook embeddings
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))  # Eq. (8) in the paper

        # Loss: only the commitment term -- the codebook is trained by EMA instead
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # loss between encoder output (inputs) and decoder input (quantized)
        loss = self._commitment_cost * e_latent_loss

        # Straight-Through Estimator: copy the gradient of the decoder input to the encoder output
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)

class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                                      for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)

class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens//2,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_2 = nn.Conv2d(in_channels=num_hiddens//2,
                                 out_channels=num_hiddens,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        x = self._conv_2(x)
        x = F.relu(x)
        x = self._conv_3(x)
        return self._residual_stack(x)

class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()

        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens//2,
                                                kernel_size=4,
                                                stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens//2,
                                                out_channels=3,
                                                kernel_size=4,
                                                stride=2, padding=1)

    def forward(self, inputs):
        x = self._conv_1(inputs)

        x = self._residual_stack(x)

        x = self._conv_trans_1(x)
        x = F.relu(x)

        return self._conv_trans_2(x)

class Model(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()

        self._encoder = Encoder(3, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

    def forward(self, x):
        # x.shape: B(256) C(3) H(32) W(32)
        z = self._encoder(x)
        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)  # decoder reconstructs the image: B(256) C(3) H(32) W(32)

        return loss, x_recon, perplexity
```
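A minimal usage sketch follows. The exact hyperparameter values here are illustrative assumptions (chosen to match the 256x3x32x32 shape comments above), not values fixed by the post:

```python
# Hypothetical hyperparameters; only the input/embedding shapes are implied by
# the comments above, the rest are illustrative assumptions.
model = Model(num_hiddens=128, num_residual_layers=2, num_residual_hiddens=32,
              num_embeddings=512, embedding_dim=64, commitment_cost=0.25,
              decay=0.99)  # decay > 0 selects VectorQuantizerEMA

x = torch.randn(256, 3, 32, 32)          # dummy image batch
vq_loss, x_recon, perplexity = model(x)
print(x_recon.shape)                     # torch.Size([256, 3, 32, 32])
```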
For the complete code, see: liujf69/VQ-VAE
3--Selected implementation details:
Reconstruction loss: compute the MSE between the source image and the reconstructed image, normalized by the variance of the training data:
```python
vq_loss, data_recon, perplexity = self.model(data)
recon_error = F.mse_loss(data_recon, data) / self.data_variance
```
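In the training loop the two terms are simply summed before backpropagation. A sketch of one optimization step (`optimizer` is an assumed torch.optim optimizer, not shown in the post):

```python
# Hypothetical training step; `optimizer` is an assumption.
loss = recon_error + vq_loss  # total loss = reconstruction term + quantization term
optimizer.zero_grad()
loss.backward()
optimizer.step()
```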
VQ quantization loss: inputs is the Encoder output, and quantized is the Codebook vector nearest to inputs:
```python
# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # third term of the loss in the paper
q_latent_loss = F.mse_loss(quantized, inputs.detach())  # second term of the loss in the paper
loss = q_latent_loss + self._commitment_cost * e_latent_loss
```
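The two detach() calls split the work: q_latent_loss updates only the codebook, e_latent_loss only the encoder. A small self-contained check (stand-in tensors, not from the original post):

```python
import torch
import torch.nn.functional as F

inputs = torch.randn(4, 8, requires_grad=True)    # stand-in for the encoder output
codebook = torch.randn(4, 8, requires_grad=True)  # stand-in for the selected embeddings

q_latent_loss = F.mse_loss(codebook, inputs.detach())  # gradient reaches the codebook only
q_latent_loss.backward()
print(inputs.grad, codebook.grad is not None)     # None True
```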
Copying the Decoder's gradient to the Encoder: inputs is the Encoder output, and quantized is the Decoder input:

```python
quantized = inputs + (quantized - inputs).detach()  # straight-through gradient copy
```
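The effect of this trick can be verified directly: the forward value equals quantized, but the backward pass behaves as if quantization were the identity, so the gradient flows to inputs. A small self-contained check (not from the original post):

```python
import torch

inputs = torch.tensor([1.0, 2.0], requires_grad=True)
quantized = torch.tensor([0.9, 2.3])         # pretend nearest-codebook lookup result

st = inputs + (quantized - inputs).detach()  # forward value == quantized
st.sum().backward()                          # backward acts as if st == inputs
print(st.detach())                           # tensor([0.9000, 2.3000])
print(inputs.grad)                           # tensor([1., 1.])
```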