当前位置:   article > 正文

通过代码学习 VQ-VAE_vqvae

vqvae

VQ-VAE(Vector Quantised Variational AutoEncoder,矢量量化变分自动编码)是【1】提出的一种离散化VAE方案,近来【2】应用VQ-VAE得到了媲美于BigGan的生成模型。由此可见, VQ-VAE 有着强大的潜力,且【1】和【2】皆为DeepMind的作品,让我们通过代码来认识它,学习它。

一、简介

光看论文一知半解,需要看看它的实现。我在GitHub中找到一个很简单的代码【3】,不妨一起研究研究。以下叙述是结合【3】的实现一起叙述的。
VQ-VAE属于VAE范畴,它有着与一般VAE都有的Encoder、code(编码)和Decoder,而不同之处在于其code并不是由Encoder直接输出得到,而是经过了一个矢量量化后才得到的,其结构图如下:
在这里插入图片描述
图1 VQ-VAE结构图【3】
在这里插入图片描述
图2 VQ-VAE数据流图【1】

结合 图1、图2 叙述其工作流程

  1. 输入x,其数据结构为[B,3,32,32],由于【3】采用了CIFAR10作为训练集,因此输入参数如此,B是batch的数量;
  2. 经过Encoder,得到 Z e ( x ) Z_e(x) Ze(x), 其结构为 [B, C=D, H, W],其中C是指编码器的Conv网络输出的Channels 的数量,而D是指矢量量化中矢量的维度,也就是后续查表(Embedding)所存储矢量的维度,另外,H,W表示输入图像经编码器处理后的长和宽,本例中,编码器输入是32 * 32,输出时为8 * 8,即H=8, W=8;
  3. Z e ( x ) Z_e(x) Ze(x) 变形为 [B * H * W, D],即每一个图片有 H*W 个编码,每个编码是D维,计算这些编码(B * H * W)与 Embedding 中 K 个矢量(在【3】中 K=512,表示矢量量化编码的矢量个数)之间的距离,通过最近邻算法构成如下映射:
    q ( z = k ∣ x ) = { 1 for  k = arg ⁡ min ⁡ j ∥ Z e ( x ) − e j ∥ 2 0 otherwise ( 1 ) q(z=k|x)=\left\{ 1amp;for k=argminjZe(x)ej20amp;otherwise
    \right. \qquad (1)
    q(z=kx)={ 10for k=argminjZe(x)ej2otherwise(1)

    公式(1)表示当输入为 x x x 时, z = k z=k z=k 的概率是 :1)当 k k k 是矢量序列 { e 1 , e 2 , ⋯   , e K } \{e_1,e_2,\cdots,e_K\} { e1,e2,,eK}中与 Z e ( x ) Z_e(x) Ze(x) 最近的矢量的下标时,条件概率为1;2)否则为0。这里的矢量距离度量采用常见的欧拉距离 ∥ ⋅ ∥ 2 \Vert \cdot \Vert_2 2,公式(1)便是最近邻算法的实现。
    z q ( x ) = e k  where  k = arg ⁡ min ⁡ j ∥ Z e ( x ) − e j ∥ 2 ( 2 ) z_q(x)=e_k \ \text{where} \ k=\arg\min_j \Vert Z_e(x)-e_j\Vert_2 \qquad(2) zq(x)=ek where k=arg
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/72998
推荐阅读
相关标签
  

闽ICP备14008679号