当前位置:   article > 正文

图像生成模型王牌——Diffusion Transformers系列工作梳理

diffusion transformer 模型

图像生成模型是目前业内研究的焦点,而目前诸如Sora等前沿生成模型,其所基于的主体架构都是Diffusion Transformers(DiT)。Diffusion Transformers(DiT)是论文Scalable Diffusion Models with Transformers(ICCV 2023)中提出的,是扩散模型和Transformer的结合,也是Sora使用的底层生成模型架构,将Diffusion Transformers从图像生成扩展到了视频生成。这篇文章给大家总结了目前主要的几个DiT模型结构,带大家梳理DiT系列模型的核心。

1

DiT

在之前的图像生成扩散模型中,底层的网络结构一般都是U-Net。而本文基于Vision Transformer(ViT)中的Transformer图像分类模型结构,替代扩散模型中的U-Net,得到DiT模型,实现了更优的生成效果。

在输入部分,基本采用了和ViT相同的方法。对输入的图像分成多个patch,并转换成一个token序列,每个token拼接上相应的position embedding。这个底层的embedding序列作为后续DiT模块的输入。

1fceb48b278b4cfcbc94a549b0cd61bf.png

在扩散模型中,Transformer除了像ViT那样输入图像patch token序列,往往还要输入一些额外的信息,包括扩散模型中当前的生成时间步、文本信息的输入等,如何将这些信息输入到DiT中,文中尝试了几种方案。最简单的方法是将这些额外的embedding直接拼接到原始的序列上。第二种是将外部的embedding单独拼接成一个序列,和原始的图像patch序列额外做一个cross attention。第三种方法是修改Transformer中的layer normalization模块,将其替换成adaptive layer normalization,LN的均值和方差由外部embedding的加和生成。第四种是在第三种的基础上,引入了基于外部embedding生成的缩放因子,对multi-head attention的输出进行缩放。

0aaef7ec39e9656c84fe5a2a2faebe2e.png

在经过多层的DiT模型后,需要将预测的噪声结果还原出来,这里使用一个MLP作为Decoder,将DiT生成的结果映射到噪声预测结果。

上述就是DiT的整体结构,主要还是Vision Transformer。用这个DiT结构,替代扩散模型中的去噪模块,也就是噪声预测网络,就是DiT模型

从实验对比中可以看出,DiT的生成效果是超过基于U-Net等之前的SOTA模型的。

28835bb8c43e14ab9bcde95ecc07903d.png

97f763255b5d4c56db7d393887dbd156.png

2

U-ViT

U-ViT是另一个基于ViT的扩散模型网络。U-ViT也是将扩散模型中的噪声预测网络替换成Transformer结构,并且借鉴了U-Net等传统CV模型中的残差网络思路,每一层的输出都会通过龙skip connection加到更深层的网络中。此外,文中对一些模型结构也进行了尝试,包括残差网络怎么加,是直接拼接到深层+MLP还是add到生成;扩散步骤embedding怎么加入到U-ViT中;以及Transformer之后的卷积网络怎么加。

43a67219e1b0eb3b5b49c7361e30651f.png

3

MDT

MDT发表于论文Masked diffusion transformer is a strong image synthesizer(ICCV 2023),在DiT的基础上,引入了mask latent modeling,进一步提升了DiT的收敛速度和生成效果。

文中分析发现,DiT在学习过程中,并不能很好的学习各个语义单元之间的关系。为了解决这个问题,MDT引入了一个重构任务,对输入的图像的部分patch进行mask,然后使用一个Transformer模型在生成过程中,对这部分被mask掉的patch进行还原。在扩散模型中,每一层MDT输入被mask掉一部分的token序列,只根据这部分序列进行噪声预测。同时,使用一个Transformer网络来还原被mask掉的部分。通过这种方式,让模型在学习过程中强行学习patch之间的关系。同时通过position embedding的引入提升对mask token的还原能力。

由于在生成阶段,decoder在处理token的时候都是没有mask的,训练的时候是mask的,这种不一致会影响效果。因此文中采用side-interpolater,对被mask掉的部分使用side-interpolater的预测结果,融合上没被mask的结果,保证训练和预测阶段decoder的输入都是没有mask掉的。

14f8fd19c682a3bcd753bd2d6290763f.png

4

Diffit

Diffit是英伟达发表于论文Diffit: Diffusion vision transformers for image generation(2023)中的一种方法,也是Diffusion Transformer的一个变体,在模型结构上进行了改进。整体的结构类似于U-Net和Transformer的结合,通过增加downsample和upsample实现层次性的建模。

e16944c7370c9f21bd7d9a7fbd7e1664.png

Diffit在引入扩散步骤embedding的时候,采用了一种Time-dependent Self-Attention的方式,即将步骤embedding直接加入到输入token序列上,让self-attention在计算的过程中就考虑到扩散步骤的信息。在模型结构上,采用U-Shape的形式,Encoder部分每一层Transformer后做downsample,来提取不同分辨率下的图像信息,Decoder部分再逐渐upsample。

a7250b54ba5800523abe79e01fa97ee5.png

fd38a1cc8951f61512249b5a2f0133df.png

END

8930e7f8dfef575a7cb1a13181b816b5.png

分享

收藏

点赞

在看

58c5b7764904bb393e787292cad1ff0e.gif

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/930105
推荐阅读
相关标签
  

闽ICP备14008679号