
AIGC Notes -- Building a VQ-VAE Model


1--The VQ-VAE Model

        The content generated by a VAE is often of limited quality. One possible reason is that the VAE encodes images into continuous variables (mapped to a standard normal distribution), whereas encoding images into discrete variables may work better: in everyday life we tend to describe things in discrete categories, e.g., a person is tall or short, fat or thin.

        The VQ-VAE model has three key modules: the Encoder, the Decoder, and the Codebook.

        The Encoder encodes the input into feature vectors. For each feature vector, its similarity (L2 distance) to the embedding vectors in the Codebook is computed, and the most similar embedding vector is taken as a stand-in for the feature vector and fed into the Decoder to reconstruct the input.
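        As a toy illustration of this nearest-neighbor lookup (the tensor sizes below are made up for demonstration and are not part of the model):

import torch

codebook = torch.randn(8, 4)                  # a toy Codebook: 8 embedding vectors of dim 4
features = torch.randn(5, 4)                  # 5 encoder feature vectors
dists = torch.cdist(features, codebook) ** 2  # squared L2 distances, shape (5, 8)
indices = dists.argmin(dim=1)                 # index of the nearest embedding per feature
quantized = codebook[indices]                 # shape (5, 4): what gets fed to the Decoder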

        The VQ-VAE loss consists of the reconstruction loss between the source image and the reconstructed image, plus the quantization loss (vq_loss) incurred by the Codebook quantization step.
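        For reference, the full objective from the VQ-VAE paper (van den Oord et al., 2017) is

        L = \log p(x \mid z_q(x)) + \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2 + \beta \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2

        where sg[·] is the stop-gradient operator and β is the commitment cost: the first term is the reconstruction loss, and the remaining two terms make up the quantization loss vq_loss.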

        For a detailed introduction to VQ-VAE, see: 轻松理解 VQ-VAE

2--A Simple Code Example

import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate squared L2 distances to all codebook embeddings
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding: index of the nearest embedding for each feature vector
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # third term of the paper's loss (commitment loss)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())  # second term of the paper's loss (codebook loss)
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()  # gradient copy (straight-through estimator)
        avg_probs = torch.mean(encodings, dim=0)
        # perplexity: the effective number of codebook entries in use
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()
        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape  # e.g. B(256) H(8) W(8) C(64)

        # Flatten input BHWC -> (BHW, C)
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances to every embedding in the embedding space
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding: take the most similar embedding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)  # map indices to one-hot vectors

        # Quantize and unflatten: look up the selected embedding by index
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            # N_i <- decay * N_i + (1 - decay) * n_i  (EMA of cluster sizes)
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            # m_i <- decay * m_i + (1 - decay) * sum of features assigned to entry i
            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            # e_i = m_i / N_i, Eq. (8) in the paper
            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss: only the commitment term is needed, since the codebook is updated by EMA
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # loss between the encoder output (inputs) and the decoder input (quantized)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()  # trick: copy the gradient of the decoder input back to the encoder output
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        return x + self._block(x)

class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                                      for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)

class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens//2,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_2 = nn.Conv2d(in_channels=num_hiddens//2,
                                 out_channels=num_hiddens,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)
        x = self._conv_2(x)
        x = F.relu(x)
        x = self._conv_3(x)
        return self._residual_stack(x)

class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()
        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens//2,
                                                kernel_size=4,
                                                stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens//2,
                                                out_channels=3,
                                                kernel_size=4,
                                                stride=2, padding=1)

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = self._residual_stack(x)
        x = self._conv_trans_1(x)
        x = F.relu(x)
        return self._conv_trans_2(x)

class Model(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()
        self._encoder = Encoder(3, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

    def forward(self, x):
        # x.shape: B(256) C(3) H(32) W(32)
        z = self._encoder(x)
        z = self._pre_vq_conv(z)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)  # decode back to an image: B(256) C(3) H(32) W(32)
        return loss, x_recon, perplexity
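A minimal usage sketch (the hyperparameter values here are illustrative assumptions, roughly matching the CIFAR-10-sized shape comments above, not values prescribed by this post):

model = Model(num_hiddens=128, num_residual_layers=2, num_residual_hiddens=32,
              num_embeddings=512, embedding_dim=64, commitment_cost=0.25,
              decay=0.99)                # decay > 0 selects the EMA quantizer

x = torch.randn(256, 3, 32, 32)          # a dummy batch of 32x32 RGB images
vq_loss, x_recon, perplexity = model(x)
print(x_recon.shape)                     # torch.Size([256, 3, 32, 32])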

For the complete code, see: liujf69/VQ-VAE

3--Notes on Selected Details

Computing the reconstruction loss:

        Compute the MSE loss between the source image and the reconstructed image, divided by the variance of the training data (data_variance); a minimal training-step sketch follows the snippet.

vq_loss, data_recon, perplexity = self.model(data)
recon_error = F.mse_loss(data_recon, data) / self.data_variance
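
For context, a minimal training-step sketch around this snippet (the Adam optimizer and the bare model/data/data_variance names are assumptions, since the surrounding training loop is not shown here):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

vq_loss, data_recon, perplexity = model(data)
recon_error = F.mse_loss(data_recon, data) / data_variance  # normalized reconstruction loss
loss = recon_error + vq_loss                                # total objective
optimizer.zero_grad()
loss.backward()
optimizer.step()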

Computing the VQ quantization loss:

        Here inputs is the Encoder's output, and quantized is the vector in the Codebook closest to inputs; the mapping to the paper's notation is spelled out after the snippet.

# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # third term of the paper's loss (commitment loss)
q_latent_loss = F.mse_loss(quantized, inputs.detach())  # second term of the paper's loss (codebook loss)
loss = q_latent_loss + self._commitment_cost * e_latent_loss
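
In the paper's notation (with sg[·] the stop-gradient operator and β = commitment_cost), the two MSE terms correspond to

        q_latent_loss = \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2  (codebook loss: gradients only reach the embeddings)
        e_latent_loss = \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2  (commitment loss: gradients only reach the Encoder)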

Copying the Decoder's gradient to the Encoder: inputs is the Encoder's output and quantized is the Decoder's input; a standalone demonstration follows the line below.

quantized = inputs + (quantized - inputs).detach()  # gradient copy (straight-through estimator)
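
A tiny standalone check of this straight-through trick (torch.round stands in for the non-differentiable quantization step; this substitution is an illustrative assumption):

import torch

x = torch.randn(3, requires_grad=True)
q = torch.round(x)            # non-differentiable stand-in for quantization
y = x + (q - x).detach()      # forward pass equals q, backward pass acts as identity
y.sum().backward()
print(x.grad)                 # tensor([1., 1., 1.]): the gradient passes straight through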
