赞
踩
之前总结过一篇关于 VAE 的文章,这次来介绍它的离散版本 VQ-VAE。
VAE(Variational Autoencoder)简单记录
论文: Neural Discrete Representation Learning
代码: https://gitee.com/mirrors_ritheshkumar95/pytorch-vqvae
v2: Generating Diverse High-Fidelity Images with VQ-VAE-2
这里选取《生成模型之VQ-VAE》一文中的代码。之所以粘贴过来,是因为我想在代码上做一些笔记,仅用于学习用途。
class VectorQuantizer(nn.Module):
    """Vector-quantization layer of a VQ-VAE.

    Every channel vector of the input feature map is replaced by its nearest
    row (squared L2 distance) of a learnable codebook.

    Args:
        embedding_dim (int): size of each codebook vector. The channel
            dimension of the input to ``forward`` must equal this value.
        num_embeddings (int): number of vectors in the codebook.
        commitment_cost (float): weight of the commitment term of the loss
            (Beta in equation 4 of the VQ-VAE paper).
    """

    def __init__(self, embedding_dim, num_embeddings, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Learnable codebook: one row per code vector.
        self.embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim)

    def forward(self, x):
        """Quantize a ``[B, C, H, W]`` feature map.

        Args:
            x: tensor of shape ``[B, C, H, W]`` with ``C == embedding_dim``.

        Returns:
            In eval mode (``self.training`` is False): the quantized tensor,
            shape ``[B, C, H, W]``.
            In training mode: a tuple ``(quantized, loss)`` where
            ``quantized`` passes gradients straight through to ``x`` and
            ``loss`` is the codebook loss plus ``commitment_cost`` times the
            commitment loss.

        Raises:
            ValueError: if the channel dimension of ``x`` is not
                ``embedding_dim``.
        """
        # [B, C, H, W] -> [B, H, W, C]
        x = x.permute(0, 2, 3, 1).contiguous()

        # Without this check, a channel count that is a different multiple of
        # embedding_dim would be reshaped into the wrong vectors and quantized
        # without any error.
        if x.shape[-1] != self.embedding_dim:
            raise ValueError(
                f"expected input with {self.embedding_dim} channels (dim 1), "
                f"got {x.shape[-1]}"
            )

        # [B, H, W, C] -> [BHW, C]
        flat_x = x.reshape(-1, self.embedding_dim)

        encoding_indices = self.get_code_indices(flat_x)
        quantized = self.quantize(encoding_indices)
        quantized = quantized.view_as(x)  # [B, H, W, C]

        if not self.training:
            return quantized.permute(0, 3, 1, 2).contiguous()

        # Codebook loss: move the embeddings towards the encoder's output.
        q_latent_loss = F.mse_loss(quantized, x.detach())
        # Commitment loss: keep the encoder's output close to its chosen code.
        e_latent_loss = F.mse_loss(x, quantized.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: the forward value is the quantized
        # vector, while the gradient flows to x as if this were the identity.
        quantized = x + (quantized - x).detach()

        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return quantized, loss

    def get_code_indices(self, flat_x):
        """Return the index of the nearest codebook row for each row of ``flat_x``.

        Args:
            flat_x: tensor of shape ``[N, embedding_dim]``.

        Returns:
            Integer tensor of shape ``[N]``.
        """
        # argmin yields integer indices that carry no gradient, so there is
        # no need to record the distance computation for autograd.
        with torch.no_grad():
            weight = self.embeddings.weight
            # Squared L2 distance expanded as |x|^2 + |e|^2 - 2 * x.e
            distances = (
                torch.sum(flat_x ** 2, dim=1, keepdim=True)
                + torch.sum(weight ** 2, dim=1)
                - 2. * torch.matmul(flat_x, weight.t())
            )  # [N, M]
            encoding_indices = torch.argmin(distances, dim=1)  # [N,]
        return encoding_indices

    def quantize(self, encoding_indices):
        """Return the codebook vectors for a batch of indices."""
        return self.embeddings(encoding_indices)
class Encoder(nn.Module):
    """Convolutional encoder of the VQ-VAE.

    The network is one 3x3, stride-2 convolution followed by ReLU for each
    entry of ``hidden_dims``, then a final 1x1 convolution that projects to
    ``latent_dim`` channels. With the default ``hidden_dims`` the layer layout
    is Conv(in_dim->32) / ReLU / Conv(32->64) / ReLU / Conv(64->latent_dim).

    Args:
        in_dim (int): number of channels of the input.
        latent_dim (int): number of channels of the output.
        hidden_dims (sequence of int): output channels of each stride-2
            convolution, in order. Defaults to ``(32, 64)``.
    """

    def __init__(self, in_dim=3, latent_dim=16, hidden_dims=(32, 64)):
        super().__init__()
        self.in_dim = in_dim
        self.latent_dim = latent_dim
        self.hidden_dims = tuple(hidden_dims)

        layers = []
        channels = in_dim
        for width in self.hidden_dims:
            layers.extend((
                nn.Conv2d(channels, width, 3, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ))
            channels = width
        # 1x1 projection to the latent channel count.
        layers.append(nn.Conv2d(channels, latent_dim, 1))

        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolutional stack to ``x`` and return the result."""
        return self.convs(x)
class Decoder(nn.Module):
    """Transposed-convolution decoder of the VQ-VAE.

    The network is one 3x3, stride-2 transposed convolution (padding 1,
    output_padding 1) followed by ReLU for each entry of ``hidden_dims``,
    then a final 3x3, stride-1 transposed convolution that projects to
    ``out_dim`` channels. With the default ``hidden_dims`` the layer layout is
    ConvT(latent_dim->64) / ReLU / ConvT(64->32) / ReLU / ConvT(32->out_dim).

    Args:
        out_dim (int): number of channels of the output.
        latent_dim (int): number of channels of the input.
        hidden_dims (sequence of int): output channels of each stride-2
            transposed convolution, in order. Defaults to ``(64, 32)``.
    """

    def __init__(self, out_dim=1, latent_dim=16, hidden_dims=(64, 32)):
        super().__init__()
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.hidden_dims = tuple(hidden_dims)

        layers = []
        channels = latent_dim
        for width in self.hidden_dims:
            layers.extend((
                nn.ConvTranspose2d(channels, width, 3, stride=2, padding=1,
                                   output_padding=1),
                nn.ReLU(inplace=True),
            ))
            channels = width
        # Final stride-1 layer maps to the requested output channels.
        layers.append(nn.ConvTranspose2d(channels, out_dim, 3, padding=1))

        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the transposed-convolution stack to ``x`` and return the result."""
        return self.convs(x)
class VQVAE(nn.Module):
    """Full VQ-VAE model: encoder, vector quantizer and decoder chained together.

    Args:
        in_dim: channel count passed to the encoder as its input size and to
            the decoder as its output size.
        embedding_dim: channel count of the latent feature map and size of
            each codebook vector.
        num_embeddings: number of codebook vectors.
        data_variance: value the reconstruction MSE is divided by.
        commitment_cost: weight of the commitment loss, forwarded to the
            quantizer.
    """

    def __init__(self, in_dim, embedding_dim, num_embeddings, data_variance,
                 commitment_cost=0.25):
        super().__init__()

        # Plain configuration values kept for reference.
        self.in_dim = in_dim
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.data_variance = data_variance

        # Sub-modules, registered in pipeline order.
        self.encoder = Encoder(in_dim, embedding_dim)
        self.vq_layer = VectorQuantizer(embedding_dim, num_embeddings,
                                        commitment_cost)
        self.decoder = Decoder(in_dim, embedding_dim)

    def forward(self, x):
        """Run the model on ``x``.

        Returns:
            In training mode: the total loss, i.e. the quantizer loss plus
            the reconstruction MSE divided by ``data_variance``.
            In eval mode: a tuple ``(quantized, reconstruction)``.
        """
        latents = self.encoder(x)

        if self.training:
            codes, vq_loss = self.vq_layer(latents)
            reconstruction = self.decoder(codes)
            scaled_mse = F.mse_loss(reconstruction, x) / self.data_variance
            return vq_loss + scaled_mse

        # Eval mode: the quantizer returns only the quantized tensor.
        codes = self.vq_layer(latents)
        reconstruction = self.decoder(codes)
        return codes, reconstruction
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。