赞
踩
最近在看图像生成相关论文,记录一下学习内容。感觉只看论文有点干巴,所以理论代码一对一上。
VQGAN (Vector Quantized Generative Adversarial Network) 是一种基于 GAN 的生成模型,可以将图像或文本转换为高质量的图像。
上图是论文的总体模型图。下面具体来看看如何实现的。
VQGAN整体模型需要两步训练。
如上图所示,从一张输入图片开始(一般是RGB图片) x ∈ R H × W × 3 x \in \mathbb{R}^{H\times W×3} x∈RH×W×3,其通过CNN Encoder编码后得到中间特征变量 z ^ ∈ R h × w × n z \hat z \in \mathbb{R}^{h\times w×n_z} z^∈Rh×w×nz。这时再引入一个codebook,注意,如果是普通的AutoEncoder,则会将 z ^ \hat z z^ 直接送入解码器中进行图像重建。而在VQVAE/VQGAN中,会将 z ^ \hat z z^进行进一步离散化编码成 z q ∈ R h × w × n z z_q\in \mathbb{R}^{h\times w×n_z} zq∈Rh×w×nz。
具体做法为:预先生成一个离散数值的codebook Z = { z k } k = 1 K , z k ∈ R n z \mathcal Z=\{z_k\}_{k=1}^{K},z_k \in \mathbb{R}^{n_z} Z={zk}k=1K,zk∈Rnz,在 z ^ \hat z z^ 的每一个编码位置都去 Z \mathcal Z Z中去寻找其距离最近的code,生成具有相同维度的变量。特别注意,这里 z ^ , z q \hat z,z_q z^,zq和 Z \mathcal Z Z中的单个编码特征的维度都为 n z n_z nz。这一步离散编码的过程就叫做“quantization”, 也就是上面的那个公式。
这样一来,就可以在已经数值离散化的
z
q
z_q
zq基础上使用CNN Decoder进行解码:
x
^
=
G
(
z
q
)
=
G
(
q
(
E
(
x
)
)
)
\hat x=G(z_q)=G(q(E(x)))
x^=G(zq)=G(q(E(x)))
整个过程的自监督损失如下:
L
V
Q
(
E
,
G
,
Z
)
=
∣
∣
x
−
x
^
∣
∣
2
+
∣
∣
s
g
[
E
(
x
)
]
−
z
q
∣
∣
2
+
∣
∣
s
g
(
z
q
)
−
E
(
x
)
∣
∣
2
\mathcal L_{VQ}(E,G,Z)=||x-\hat x||^2+||sg[E(x)]-z_q||^2+||sg(z_q)-E(x)||^2
LVQ(E,G,Z)=∣∣x−x^∣∣2+∣∣sg[E(x)]−zq∣∣2+∣∣sg(zq)−E(x)∣∣2其中,上式中的第一项
L
r
e
c
\mathcal L_{rec}
Lrec 为重建损失(reconstruction loss)
s
g
[
⋅
]
sg[·]
sg[⋅] 为梯度终止操作(stop-gradient operation),其目的在于保证神经网络梯度可以正常回传,而不受离散编码的影响。因此在codebook的搭建过程中,我们看到由
z
^
\hat z
z^得到
z
q
z_q
zq之后,先计算出公式中后两项损失,然后又增加了一步detach操作。
loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
z_q = z + (z_q - z).detach()
这么一来,在其后面计算
L
r
e
c
\mathcal L_{rec}
Lrec,即公式的第一项中,
z
q
z_q
zq的梯度可以顺利复制到
z
^
\hat z
z^上,而不受离散编码过程的干扰。除了这个重建过程使用的自监督损失外,还加入了GAN中的对抗loss。文章里没有具体写对抗loss的类型。通过源码可以发现使用的是hinge loss。对于判别器而言,其损失函数可以笼统地表示为:
L
G
A
N
(
{
E
,
G
,
Z
}
,
D
)
=
l
o
g
D
(
x
)
+
l
o
g
(
1
−
D
(
x
^
)
)
\mathcal L_{GAN}(\{E,G,\mathcal Z\}, D)=logD(x)+log(1-D(\hat x))
LGAN({E,G,Z},D)=logD(x)+log(1−D(x^))
所以总的误差可以写成:
L
=
L
V
Q
+
λ
L
G
A
N
\mathcal L = \mathcal L_{VQ}+\lambda \mathcal L_{GAN}
L=LVQ+λLGAN
总结来说就是:
x
→
z
^
→
z
q
→
x
^
x\to \hat z\to z_q\to \hat x
x→z^→zq→x^
下面主要来看看这三部分的代码
CNN Encoder, CNN Decoder是一种基于UNet的代码结构,具体细节可以从原文中获取,这里不在细说
class Encoder(nn.Module): def __init__(self, args): super(Encoder, self).__init__() channels = [128, 128, 128, 256, 256, 512] attn_resolutions = [16] num_res_blocks = 2 resolution = 256 layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)] for i in range(len(channels)-1): in_channels = channels[i] out_channels = channels[i + 1] for j in range(num_res_blocks): layers.append(ResidualBlock(in_channels, out_channels)) in_channels = out_channels if resolution in attn_resolutions: layers.append(NonLocalBlock(in_channels)) if i != len(channels)-2: layers.append(DownSampleBlock(channels[i+1])) resolution //= 2 layers.append(ResidualBlock(channels[-1], channels[-1])) layers.append(NonLocalBlock(channels[-1])) layers.append(ResidualBlock(channels[-1], channels[-1])) layers.append(GroupNorm(channels[-1])) layers.append(Swish()) layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1)) self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x)
具体的模块定义可以阅读源代码,这个都不难理解。
class Decoder(nn.Module): def __init__(self, args): super(Decoder, self).__init__() channels = [512, 256, 256, 128, 128] attn_resolutions = [16] num_res_blocks = 3 resolution = 16 in_channels = channels[0] layers = [nn.Conv2d(args.latent_dim, in_channels, 3, 1, 1), ResidualBlock(in_channels, in_channels), NonLocalBlock(in_channels), ResidualBlock(in_channels, in_channels)] for i in range(len(channels)): out_channels = channels[i] for j in range(num_res_blocks): layers.append(ResidualBlock(in_channels, out_channels)) in_channels = out_channels if resolution in attn_resolutions: layers.append(NonLocalBlock(in_channels)) if i != 0: layers.append(UpSampleBlock(in_channels)) resolution *= 2 layers.append(GroupNorm(in_channels)) layers.append(Swish()) layers.append(nn.Conv2d(in_channels, args.image_channels, 3, 1, 1)) self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x)
我最开始看的时候,最不明白的地方就是这个codebook,一直在想,这兄弟是哪蹦出来的。其实就是另外定义的一个网络,说白了甚至算不上一个网络就是一个
nn.Embedding()
,还是之前没看VQVAE的锅。
class Codebook(nn.Module): def __init__(self, args): super(Codebook, self).__init__() self.num_codebook_vectors = args.num_codebook_vectors self.latent_dim = args.latent_dim self.beta = args.beta self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim) self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors) def forward(self, z): z = z.permute(0, 2, 3, 1).contiguous() z_flattened = z.view(-1, self.latent_dim) d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \ torch.sum(self.embedding.weight**2, dim=1) - \ 2*(torch.matmul(z_flattened, self.embedding.weight.t())) min_encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(min_encoding_indices).view(z.shape) loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2) z_q = z + (z_q - z).detach() z_q = z_q.permute(0, 3, 1, 2) return z_q, min_encoding_indices, loss
经VQGAN得到的压缩图像与真实图像有一个本质性的不同:真实图像的像素值具有连续性,相邻的颜色更加相似,而压缩图像的像素值则没有这种连续性。
压缩图像的这一特性让寻找一个压缩图像生成模型变得异常困难。多数强大的真实图像生成模型(比如GAN)都是输出一个连续的浮点颜色值,再做一个浮点转整数的操作,得到最终的像素值。而对于压缩图像来说,这种输出连续颜色的模型都不适用了。而恰好,Transformer天生就支持建模离散的输出。在NLP中,每个单词都可以用一个离散的数字表示。Transformer会不断生成表示单词的数字,以达到生成句子的效果。
VQGAN的作者使用了自回归图像生成模型的常用做法,给图像的每个像素从左到右,从上到下规定一个顺序。有了先后顺序后,图像就可以被视为一个一维句子,可以用Transfomer生成句子的方式来生成图像了。在第i 步,Transformer会根据前i−1 个像素 s < i s_{<i} s<i生成第 i i i 个像素 s i s_i si.
来看具体实现——训练过程:
现在进入第二步,这篇论文毕竟是个图像生成的任务,注意之前的三个零件已经训练好不动了,现在我们需要得到一组排列好的code,送进CNN Decoder中来实现图像生成。那么这组code怎么来的?这就是Transformer发挥作用的地方了。该工作使用的Transformer模型为著名的GPT-2。迁移到VQGAN中,即可理解为先预测一个code,再一步步地通过已经预测好的code去推断下一个code。
code都是从训练好的codebook Z \mathcal Z Z中寻找,就像写文章一样,你有词典了,现在你要从词典中一个字一个字的写成一篇新文章
为了训练Transformer,
假设被替换后的code组合的索引为modified_indices,原本
z
q
z_q
zq的code索引为unmodified_indices,那么Transformer的学习过程即为:喂入modified_indices,通过训练学习重构出unmodified_indices。
L
t
r
a
n
s
f
o
r
m
e
r
=
E
x
∼
p
(
x
)
[
−
l
o
g
p
(
s
)
]
\mathcal L_{transformer}=\mathbb E_{x\sim p(x)}[-logp(s)]
Ltransformer=Ex∼p(x)[−logp(s)]
代码具体实现如下:
""" 首先得到由x前传得到的unmodified_indices """ sos_tokens = torch.ones(x.shape[0], 1) * self.sos_token # (B, 1), sos_token是一个整数,表示从第几个token开始预测,一般为0 mask = torch.bernoulli(self.pkeep * torch.ones(unmodified_indices.shape, device=unmodified_indices.device)) # (B, h*w), 元素都为0和1,0的是mask掉的元素,1是保留的元素(比例为pkeep) mask = mask.round().to(dtype=torch.int64) random_indices = torch.randint_like(indices, self.transformer.config.vocab_size) # (B, h*w), 生成一些任意的indices,用来填充被遮挡的部分 modified_indices= mask * unmodified_indices+ (1 - mask) * random_indices # (B, h*w), mask为1(未遮挡)部分仍然保留原始indices,mask为0(遮挡)部分用random_indices填充 modified_indices= torch.cat((sos_tokens, modified_indices), dim=1) # (B, h*w+1),将0放到第一个indice前面 targets = unmodified_indices logits, _ = self.transformer(modified_indices[:, :-1]) # logits: (B, h*w, num_codebook_vectors), 意思是h*w个indices处,预测出来的对应每一个codebook_vector的概率 """ 然后再由logits和targets之间计算交叉熵损失 """ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
注意这是训练的过程,不是生成的过程。在VQGAN无条件生成图片的过程中,没有任何先验条件,CNN Encoder直接被弃用。我们需要得到一组排列好的code,送进CNN Decoder中来实现图像生成。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。