赞
踩
传统的扩散模型基于一个U-Net骨架,这篇文章提出了一种新的扩散模型结构,将U-Net替换为一个transformer,并将这种结构称为Diffusion Transformers (DiTs)。他们还发现,transformer的规模越大(通过Gflops衡量),生成的图片的质量越好(FID越低)。
如图2所示,DiT的规模越大,图片生成的质量越好(左图),和当前流行的扩散模型相比,DiT的计算效率也表现优异。
Diffusion Transformers (DiTs)是基于Vision Transformer (ViT)的模型,它的大体结构如图3所示,从左图可以看到,输入的噪音特征被分解为不同批,然后被若干个DiT块处理;右边的三张图展示了DiT块的详细结构,分别是三种不同的变体。
下面对DiT的各层进行分析:
Patchify. 从图3中可以看到,DiT的第一个层是Patchify,其将输入转化为
T
T
T个token序列。在这之后,作者使用标准ViT中基于频率的位置嵌入处理前面的token序列。而token序列的数量是由一个超参数
p
p
p决定的,
p
p
p减半导致
T
T
T翻四倍,并且导致整个transformer的GFlops至少翻四倍,如图4所示。
DiT block design. 在patchfiy层之后,几个transformer块处理输入token以及一些额外的条件信息,比如,类标签
c
c
c和时间步数
t
t
t。作者尝试了4种不同的ViT变体:
Model Size. 作者设置了四种规模的DiT:DiT-S, DiT-B, DiT-L and DiT-XL,结构复杂度依次增大。
Transformer decoder. 在经过最后的DiT块之后,使用tranformer decoder将输入tokens转化为和输入同等性状的噪音预测。
综上,作者探索了DiT设计空间中的patch_size、transformer架构(4种,in-context,cross-attention, adaptive layer
norm and adaLN-Zero blocks)和model size(4种,DiT-S, DiT-B, DiT-L and DiT-XL)。
DiT block design. 四个不同的DiT块:in-context (119.4 Gflops), cross-attention (137.6 Gflops),
adaptive layer norm (adaLN, 118.6 Gflops) or adaLN-zero (118.6 Gflops)中, adaLN-zero (118.6 Gflops) 取得最低的FID。其中,adaLN-zero相较于adaptive layer norm的提升,说明了恒等映射的好处。(后续的实验除非特别说明都是在adaLN-zero上做的)
Scaling model size and patch size. 模型size增大和patch zise减小,均会提高Gflops,降低FID。我们注意到,DiT-L 和DiT-XL的FID很接近,因为它们的Gflops也相对更接近。
DiT Gflops are critical to improving performance. 上面的图6再次说明了模型参数量的增大并不等同于DiT模型的图片质量提高,真正的关键是提高Gflops。比如,DiT S/2的表现和DiT B/4接近,因为小的batch size会增大Gflops,二者的Gflops接近,所以FID也接近。
Larger DiT models are more compute-efficient
小的DiT模型即便训练时间更长,相对于训练时间更短的大的DiT模型,其计算效率也是更差的。
这里,作者估计训练计算量的方式为model Gflops · batch size · training steps · 3。
和主流的扩散模型相比,DiT-XL/2 (即参数量最大,patch size最小的DiT)的表现最优。
扩散模型有一个比较特殊的点,在生成图片时,它可以通过增加调整采样步数,引入额外的增加的计算量,但是,这并不能弥补训练时模型计算量的差距,即大GFlops的DiT在采样步数少的情况下,仍然能比小GFlops的DiT在采样步数多的情况下,取得更低的FID。
Diffusion Transformers (DiTs)作为一种新的扩散模型,比基于U-Net的扩散模型表现更加优异。并且,其在模型复杂度提高的时候,能够有明显的性能提高,因此,使用更大规模的DiT有助于提高模型性能。此外,DiT也可以用于文生图生成任务。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。