当前位置:   article > 正文

解密 Sora 背后的魔法——Diffusion Transformer_解密sora diffusion transformer

解密sora diffusion transformer

解密 Sora 背后的魔法——Diffusion Transformer

在这里插入图片描述

Diffusion Transformer, 简称 DiT,来自 William Peebles 和 Saining Xie 的论文Scalable Diffusion Models with Transformers

DiT 影响了其他基于 Transformer 的扩散模型的发展,如 PIXART-αSora,以及最近发布的 Stable Diffusion 3。在 DiT 之前,几乎所有早期的扩散模型都基于 U-Net 卷积架构。随着 Sora 惊艳世界,Transformer 正逐渐替代 U-Net,成为图像和视频生成领域的新宠。本文将带大家揭示 Sora 背后的秘密,一同探索 Diffusion Transformer 的原理。

扩散模型

Diffusion Transformer 这是一个复杂的高级算法,要想弄懂 DiT 需要读者熟悉 AI 中的一些常见概念,特别是对图像生成要有一定的了解。

如果你已经熟悉这个领域,本节将帮助你复习这些概念,从而方便你更好得理解后面得内容。

如果你对图像生成不是很了解,推荐你阅读我之前的系列文章《Stable Diffusion 保姆级教程》《Diffusion Model 深入剖析》《Stable Diffusion 超详细讲解》《Stable Diffusion — ControlNet 超详细讲解》。这些文章涵盖了许多扩散模型和相关技术,其中一些我们将在这里重新讨论。

扩散公式

在这里插入图片描述

直观上讲,扩散模型的工作原理是:首先获取图像,引入噪声(通常是高斯噪声),然后训练一个神经网络通过预测被添加的噪声或噪声的协方差矩阵来逆转这一增加噪声的过程。引入的噪声量由时间帧变量 t t t 控制;其中在 t = 0 t=0 t=0 时, x 0 x_0 x0 表示原始图像,而在 t = 1000 t=1000 t=1000 时, x 1 000 x_1000 x1000 几乎是纯噪声。

实际操作中,对于每一个时间帧 t t t,我们从高斯分布中抽样 x t x_t xt,条件依赖于 x t − 1 x_{t-1} xt1
q ( x t ∣ x t − 1 ) : = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) := \mathcal{N}(x_t; \sqrt {1-\beta_t}x_{t-1}, \beta_tI) q(xtxt1):=N(xt;1βt xt1,βtI)
这里可以通过以下方式重参数化:
x t = 1 − β t x t − 1 + β t ε t − 1 x_t = \sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}\varepsilon_{t-1} xt=1βt xt1+βt εt1
其中, ε t ∼ N ( 0 , I ) \varepsilon_t \sim \mathcal{N}(0, I) εtN(0,I) β t \beta_t βt 项表示预定的偏差进度。为了根据 x t − 1 x_{t-1} xt1生成 x t x_t xt,我们从多元标准正态分布中抽样 ε t \varepsilon_t εt,并应用上述方程。这种逐步增加噪声的方法被称为前向过程。

幸运的是,为了生成 x t x_t xt,没有必要生成所有之前的 x t − 1 x_{t-1} xt1;可以直接使用以下公式:
x t = α t ˉ x 0 + 1 − α ˉ t ε t x_t =\sqrt{\bar{\alpha_t}}x_0+\sqrt{1-\bar{\alpha}_t}\varepsilon_t xt=αtˉ x0+1αˉt εt
其中 α t ˉ = ∏ s = 1 t ( 1 − β s ) \bar{\alpha_t} = \prod_{s=1}^t(1-\beta_s) αtˉ=s=1t(1βs)。这个公式的完整推到过程参见《Diffusion Model 深入剖析》

一旦我们确定了噪声添加的方法,我们就训练一个模型来预测添加的噪声。在训练期间,我们抽取一批图像和对应的 t t t 值,根据 t t t 值添加噪声,然后将这些加了噪声的图像及其 t t t 值输入模型。模型被训练用来预测噪声,以最小化损失函数,通常是实际添加的噪声与模型预测的噪声之间的均方误差。
L s a m p l e ( θ ) = ∥ ε θ ( x t ) − ε t ∥ 2 2 \mathcal{L}_{sample}(\theta) = \Vert \varepsilon_\theta(x_t)-\varepsilon_t \Vert_2^2 Lsample(θ)=εθ(xt)εt22
需要注意的是,这里的 ε θ \varepsilon_\theta εθ 是我们的神经网络,虽然没有显示出来,但它也将时间帧 t t t 的值作为输入。

对于图像生成,我们执行逆过程,从纯噪声开始,使用以下条件分布迭代采样:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , ∑ θ ( x t , t ) ) p_\theta(x_{t-1}|x_{t}) = \mathcal{N}(x_{t-1};\mu_\theta(x_t, t),\sum_\theta (x_t, t)) pθ(xt1xt)=N(xt1;μθ(xt,t),θ(xt,t))
其中
μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t ε θ ( x t , t ) ) \mu_\theta(x_t,t) = \frac{1}{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\varepsilon_\theta(x_t,t)\Big) μθ(xt,t)=αt 1(xt1αˉt βtεθ(xt,t))
并且 ∑ θ ( x t , t ) \sum_\theta (x_t, t) θ(xt,t) 被设置为对角矩阵。在去噪扩散概率模型(Denoising Diffusion Probabilistic Models,简称 DDPM)的背景下,这个矩阵是固定的;而对于改进的去噪扩散概率模型(Improved Denoising Diffusion Probabilistic Models,简称 iDDPM),这个矩阵是通过学习得到的。DiT 使用的是 iDDPM。

有了训练好的模型,我们可以从纯噪声开始,并尝试一步去除所有噪声,以获得类似于训练数据集中的样本。然而采用逐步迭代的过程,在此过程中部分去除噪声,并引入一些“新鲜”的噪声(朗之万动力学),会产生更好的结果。逆向过程的具体细节由采样策略决定。

无分类引导

在实践中,我们很少试图在没有任何控制的情况下生成图像。最常见的方法是通过文本提示来引导图像生成,即我们常说的文本到图像转换(text-to-image)。这种方法的一个简化版本是类条件化(class conditioning),其中用类标签 c c c 作为提示。Diffusion Transformer 使用的就是这种类型的提示。

不幸的是,模型有时会忽略我们的提示,无论是文本提示还是其他形式的提示。为了应对这一点,采用了一种被称为“无分类引导(Classifier-free guidance)”的技术,以确保模型更加紧密地遵循我们的指示。

你可能会好奇为什么它被称为“无分类”引导。这种方法与分类引导不同,分类引导依赖于表示 x t x_t xt log ⁡ p ( c ∣ x t , t ) \log p(c | x_t, t) logp(cxt,t) 的梯度,其中 p ( c ∣ x t , t ) p(c | x_t, t) p(cxt,t) 表示图像 x t x_t xt 在时间帧 t t t 属于类 c c c 的可能性,并由分类器估计。函数的梯度指向最陡峭上升的方向。因此,通过增加这个梯度并赋予一定的权重,生成的图像与期望类别对齐的机会会增加,至少根据我们的分类器看是这样。

这种方法的一个主要挑战是需要一个分类器。如果我们的数据集类别没有预训练的分类器,我们必须首先训练一个。更棘手的是,即使存在合适的分类器,它也不太可能专门为我们的图像类型设计,因为预训练的分类器是为了分类“干净”的图像,而不是被高斯噪声遮蔽的图像。那么当我们想基于像“一个穿着芭蕾舞裙的小萝卜插图在遛狗”这样的文本提示或甚至非文本条件来引导我们的生成时该怎么办呢?

无分类引导有效地解决了这些障碍。它的操作方法是在训练期间偶尔用一个表示提示缺失的可学习嵌入替换提示嵌入。在推理过程中,这允许我们计算两个噪声估计:一个带有提示的,一个不带提示的。我们可以应用以下公式:
ε θ ^ ( x t , c ) = ε θ ( x t , ∅ ) + s ⋅ ( ε θ ( x t , c ) − ε θ ( x t , ∅ ) ) \hat{\varepsilon_\theta}(x_t, c) = \varepsilon_\theta(x_t, \empty)+s \cdot (\varepsilon_\theta(x_t,c)-\varepsilon_\theta(x_t,\empty)) εθ^(xt,c)=εθ(xt,)+s(εθ(xt,c)εθ(xt,))
从公式中我们可以看出,将引导尺度 s s s 设置为 1 1 1 时,噪声估计将保持在无引导过程中的状态(标准条件)。将 s s s 增加到 1 1 1 以上,会使未条件化的噪声估计 ε θ ( x t , ∅ ) \varepsilon_\theta(x_t, \empty) εθ(xt,) 更接近于与特定提示对齐的状态,向 ε θ ( x t , c ) − ε θ ( x t , ∅ ) \varepsilon_\theta(x_t,c)-\varepsilon_\theta(x_t,\empty) εθ(xt,c)εθ(xt,) 的方向移动。

尽管无分类引导通常会导致样本质量的提升,但它也需要在每次评估时做出两种噪声预测——一种带有提示的,一种是不带提示的——这实际上使得计算工作量翻倍。此外,在生成图像的质量与多样性之间存在权衡:随着 s s s 的增加,视觉保真度提高,但样本的多样性降低。在 William Peebles 和 Saining Xie 的论文中,作者在生成基于特定类别标签的图像时使用的引导尺度 s = 4 s=4 s=4

潜在扩散模型

在这里插入图片描述

关于潜在扩散模型(Latent Diffusion Models),我在《Stable Diffusion 超详细讲解》中详细的讲解。在这里我只提供简要的直观理解。

前面提到的扩散公式是在图像空间中操作的。与可以使用相对较低分辨率的分类模型不同,在我们的案例中,分辨率受到我们打算生成的图像分辨率的限制,例如 1024 × 1024 1024 \times 1024 1024×1024。此外,正如我们所见,逆向过程是迭代式的,这意味着为了生成一张图像,我们必须多次使用计算量较大的模型进行推理。使用 Transformer 而不是像 U-Net 这样的卷积网络进一步加剧了这个问题,因为注意力机制的规模是按二次方而不是线性增长的。此外,我们还会看到,从图像生成的 token 数量也是按二次方增长的,这意味着如果我们将分辨率从 256 × 256 256 \times 256 256×256 提高到 512 × 512 512 \times 512 512×512,如果之前的 token 数量是 T T T,它将变成 T 2 T^2 T2,如果注意力层中的操作数量是 O ( T 2 ) O(T^2) O(T2),它将变成 O ( T 4 ) O(T^4) O(T4)

解决这个问题基本上有两种方法:

  1. DALL·E 2(OpenAI)或 Imagen(Google)这样的模型采用的是在扩散过程中使用相对较低的分辨率,然后应用一个或多个超分辨率模型。
  2. 放弃图像空间,转而在潜在空间中工作。

第二种方法带来了潜在扩散模型。我们可以将潜在空间视为一种压缩图像的方式,它保持了其语义内容,只丢弃了一些边缘细节。能够以这种方式压缩然后解压图像的模型是变分自编码器(VAE)

在训练期间,每幅图像通过 VAE 编码器被压缩;在推理期间,只生成一个含噪声的张量 z z z。如果所需的分辨率是 256 × 256 × 3 256 \times 256 \times 3 256×256×3 z z z 可能是 32 × 32 × 4 32 \times 32 \times 4 32×32×4。注意,通道数量不再局限于保持在 3 3 3,因为不再在图像空间中工作。在生成过程中,从 z z z 开始,逐渐从中去除噪声,在最后一步中,使用 VAE 解码器将“干净”的 z z z 解压为最终图像。

Diffusion Transformer 设计空间

在这一部分,我假设你已经熟悉标准的自注意力机制和正弦位置编码的概念,这些都是与 Transformer 架构的基本概念。如果这些概念对你来说很陌生,或者你希望复习一下,推荐阅读:《深度解析 Transformer 和注意力机制(含完整代码实现)》

在这里插入图片描述

Patchify

在这里插入图片描述

Transformer 模型接受的输入是一个序列,或者更确切地说是一个集合(序列中的顺序仅通过位置编码/嵌入给出)。因此,我们必须执行的第一个操作是将 z z z(我们的潜在张量,在这个例子中是 32 × 32 × 4 32 \times 32 \times 4 32×32×4)转换成一系列的 token。Patchify 非常简单;它将 z z z 划分为一个网格 ( 32 / p ) × ( 32 / p ) × 4 (32 / p)\times(32 / p)\times 4 (32/p)×(32/p)×4,其中每个网格元素为 p × p × 4 p \times p \times 4 p×p×4; 然后线性投影成 1 × d 1 \times d 1×d,其中 d d d 是一个超参。在实践中,这可以通过应用卷积来轻松完成,卷积核和步长等于 p p p,输出通道数量等于 d d d,从而为每个批次元素获取一个大小为 ( 32 / p ) × ( 32 / p ) × d (32 / p) \times(32 / p) \times d (32/p)×(32/p)×d 的张量。如果我们定义 T = 3 2 2 / p 2 T=32^2 / p^2 T=322/p2,通过重新排列我们拥有的元素,对于大小为 N N N 的批次,得到一个大小为 N × T × d N \times T \times d N×T×d 的张量。 p p p 减半会使 T T T 增加 4 倍,从而使 Transformer 的总浮点运算次数(GFLOPS )至少增加 4 倍。在 William Peebles 和 Saining Xie 的论文中,作者尝试使用 p = 2 , 4 , 8 p=2,4,8 p=2,4,8;较小的 p p p 值产生更好的结果,但计算量更大。

在这一点上,回想一下注意力机制本身并不知道元素的顺序,我们为所有输入 token 添加位置编码。假设你已经熟悉一维正弦位置编码,我们将以标准的一维方式对网格中每个 toten 的 x x x 坐标进行编码。对 y y y 坐标应用相同的过程,然后将这两种编码对每个 token 进行连接。这种方法比对“平展”的图像应用标准的一维位置编码更受欢迎,因为后者会导致网格中两个相邻位置具有显著不同的编码。通过保持输入的二维性质,我们保持了相邻 token 具有相似的编码。

DiT 块设计

这里我假设你对 Transformer 架构有所了解,因此不会深入讨论层规范化(Layer Norm)、多头自注意力(Multi-Head Self-Attention)或逐点前馈(Pointwise Feedforward)这样的标准转换。这里重点介绍一个 DiT 特有的内容:我们如何实现类条件化?

实现类条件化的方法有很多,这里我只介绍 William Peebles 和 Saining Xie 实验中发现的最有效的方法:adaLN-Zero

首先,我们需要将 t t t c c c 转换成两个嵌入。对于类标签 c c c,这需要简单地为每个类初始化一个在训练期学到的嵌入向量。对于 t t t,情况稍微复杂一些:对 t t t 应用正弦编码,然后通过一个小型的多层感知器(MLP)进一步转换,该感知器由一个线性层、一个 SiLU 激活层和另一个线性层组成。这两个嵌入随后连接成条件向量。

这里简单描述一下 adaLN 的工作原理。首先,条件处理经过 SiLU 转换,然后进行线性投影;这种转换称为 adaLN_modulation。adaLN_modulation 的输出必须分成六个不同的部分,每个部分的维度为 d d d(与我们的条件相同)。这些向量代表了用于 DiT 块的不同部分缩放和移动输入的 γ 1 , β 1 , α 1 , γ 2 , β 2 , α 2 \gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2 γ1,β1,α1,γ2,β2,α2

转换的顺序如下:

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class DiTBlock(nn.Module):
    ...

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  1. 输入 token x x x 通过层规范化(Layer Norm)进行标准化:self.norm1(x)
  2. 使用 modulate(...) 函数,缩放和移动输出,其中 shift_msa, scale_msa 分别代表 γ 1 , β 1 \gamma_1, \beta_1 γ1,β1
  3. 应用多头自注意力机制:self.attn(...)
  4. 输出通过 gate_msa.unsqueeze(1) 缩放,代表 α 1 \alpha_1 α1
  5. 结果被加回原始的 x x x,从而使用残差连接(类似于 ResNet)
  6. 重复类似的过程,但使用逐点前馈 self.mlp(...),而不是 self.attn(...)

adaLN-Zero 与 adaLN 的不同之处在于,adaLN_modulation 的权重初始化为零。这意味着最初,在上述前向方法中, x x x 并没有以任何方式被改变,而网络只是逐渐学习最佳的缩放和移动参数。

Transformer 解码器

最后,解码器由一个线性层加上类似于我们已经看到的规范化组成。这个线性层将最终输出 x x x 的维度从 N × T × d N \times T \times d N×T×d 转换为 N × T × p 2 × 2 C N \times T \times p^2 \times 2C N×T×p2×2C,其中 C C C 代表输入通道。需要注意的是,只有在我们除了预测噪声外还要预测对角协方差矩阵 Σ \Sigma Σ 时,它才是 2 C 2C 2C;否则,它就只是 C C C

然后输出被“unpatchified”(重塑),以产生预测的噪声或噪声与预测的 Σ \Sigma Σ 结合的结果。

总结

本文可以视为 《Sora 技术实现》这篇文章的后续,介绍了论文《Scalable Diffusion Models with Transformers》的基本部分,同时也提供了一些有用的背景信息。

最后分享一个有趣的轶事,虽然《Scalable Diffusion Models with Transformers》这篇论文正变得日益重要,但它最初在 2023 年的 CVPR(计算机视觉和模式识别会议)上被拒绝了。这可能印证了那句话:“这个世界就是个草台班子”,即便是专家也很难预测什么会产生影响,而且科学会议的审查过程远非完美。希望 AI 可以让这个世界不再那么潦草……

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/930133
推荐阅读
相关标签
  

闽ICP备14008679号