赞
踩
这里要介绍的DALL-E[1]是OpenAI的多模态预训练模型,它的最显著的效果是在文本到图像的生成上。2021 年初,OpenAI 发布了一款名为 DALL-E 的图像生成模型,该模型说白了就是可以根据用户提供的文本描述自动生成对应的图像。一个例子是图1的DALL-E根据输入“牛油果形状的扶手椅”生成的图像,它足以达到以假乱真的效果,生成的内容不仅逼真合理,甚至可以一定程度上启发人类设计师。DALL-E通过120亿参数的模型,在2.5亿图像文本对上训练完成。它是一个两阶段的模型:它的第一个阶段是离散变分自编码器(Discrete Variance Auto-Encoder,dVAE),用于生成图像的token。它的第二个阶段是混合了图像和文本特征的,以Transformer为基础的生成模型。在DALL-E中,它使用了非常多优化模型准确率的技巧和提升训练效率的优化,下面我们来逐一介绍之。
图1:DALL-E根据输入“牛油果形状的扶手椅”生成的图像
我们知道,由于图像特征的密集性和冗余性,它是不能直接提供给Transformer进行训练的。目前主流的方式,例如ViT,Swin-Transformer等都是将图像的Patch作为模型的输入,然后通过一个步长等于Patch大小的大卷积核得到每个Patch的特征向量。DALL-E提供的方案是使用一个离散的变分自编码器(dVAE)将大小为 256×256 的RGB图像压缩到大小为 32×32 的,通道数为 8,192 的one-hot token的分布(注意这个one-hot的形式,它很重要),变分自编码器的架构如图2所示。换句话说,阶段1的作用是将图像映射到一个大小为 8,192 的图表中。这里通道数为 8,192 的one-hot向量可以看做是一个词表,它的思想和是通过离散VAE,实现图像特征空间想文本特征空间的映射。
DALL-E的离散VAE的编码器和解码器都是基于残差网络[4]构建的,DALL-E保持了残差网络的基础结构,但也有其针对性的调整,它的核心修改如下:
图3:DALL-E的阶段2的先验分布学习
阶段2的输入是拼接的文本特征和图像特征以及鸽子的位置编码等信息。图4是一个模型输入的例子,在这个示例中,文本嵌入的长度是6,图像嵌入的长度是2。文本的输入是文本的嵌入和文本的位置编码,图像的输入是图像的标志嵌入,行位置编码以及列位置编码。这些编码的长度均是$3968$。
图4:DALL-E阶段2的模型输入示例图
图5:DALL-E的阶段2的三种稀疏自注意力机制示意图,这里文本的token长度是6,图像的长度是16。(a)是行注意力,每个标志只关注top-5的相关标志,(b)是列注意力,(c)是列注意力的转置,它能够更好的利用GPU,(d)是卷积自注意力
在DALL-E的官网示例中,它给了很多文本生成图像的案例。图像生成过程如图6所示,它首先将输入文本编码成特征向量,然将特征向量送入到自回归的Transformer中生成图像的token,再后将图像的token送入到dVAE的解码器中得到生成图像,最后通过CLIP对生成样本进行评估,得到最终的生成结果。
图6:DALL-E的图像生成过程
为了提升GPU的计算效率,DALL-E的大量参数以及激活都使用了16位的低精度存储。这种低精度模型的最大挑战是梯度下溢(underflow)的问题,也就是计算的梯度值超出了16位浮点数能表示的最低值。DALL-E的使用了大量的技术来解决这一问题(论文附录D),这里重点介绍它的最重要的一个点:每个残差块的梯度缩放(per-resblock gradient scaling)。
传统的低精度训练通过将梯度值限制在一个模型能表示的范围内来避免梯度下溢。但是这种粗暴的限制每一个梯度的范围的方法并不适合DALL-E这种文本到图像更复杂的任务,它需要更过的精度表示。DALL-E的策略是对每个残差块使用单独的梯度缩放比例,因此这里将它命名为混合精度训练。它的核心点有3个:
DALL-E的一个残差块的混合精度训练如图7所示,其中实线表示前向传播的计算顺序,虚线表示反向传播的计算流程。在前向计算时,对于单位映射,我们先将其缩放到16位,当进行完卷积运算时,我们再将其放到到32位。在进行反向运算时,我们先对梯度尽心过滤和缩放,其中过滤操作会将所有NaN和Inf的值置0,经过卷积操作权值的更新后再将其放大到32位。
图7:DALL-E的混合精度训练
DALL-E的模型即使使用16位的精度来存储,也要占用大约24G的显存,这超过了他们训练环境的单卡(NVIDIA V100 16G)的硬件显存,这里他们使用了参数分片(Parameter Sharding)[7]来解决显存不足的问题。
在进行模型的参数分片训练时,一个问题是不同机器的通信问题,它们之间的带宽是远小于同一台机器的不同显卡之间的带宽的,这成为了多机多卡训练的一个瓶颈。这里DALL-E使用了PowerSGD[8]压缩梯度来大幅降低带宽成本。
关于DALL-E的混合梯度以及分布式运算,他们更多是在训练效率上的提升,这里不会过多介绍,感兴趣的阅读DALL-E的论文以及它涉及的相关参考文献。
从上面的分析中我们可以看出DALL-E又是一个OpenAI风格的文章,模型创新不多,但是靠着庞大的数据和参数取得了令人惊叹的效果。文章中涉及的一些创新更多的是面向研发过程中遇到问题的对症下药。但技术上的创新匮乏并不能研发它对于深度学习领域的巨大贡献,最起码它证实了深度学习在大量参数和数据的前提下的无限可能性。
[1] 剖析 AIGC 关键模型 —— DALL-E - 知乎
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。