赞
踩
官方代码仓库为 https://github.com/facebookresearch/DiT,下面代码的具体位置在 /path/to/DiT/models.py
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(t) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
c = t + y # (N, D)
for block in self.blocks:
x = block(x, c) # (N, T, D)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
与下面代码中 forward 函数内对应的变量在 DiT Block 中的位置。
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
在这个 DiTBlock
类中,shift_msa
、scale_msa
、gate_msa
、shift_mlp
、scale_mlp
和 gate_mlp
是从 adaLN_modulation(c)
这一步中得到的,它们在具体功能上是有所区别的,虽然它们是通过同一个输入 c
生成的。
shift_msa
和 scale_msa
这两个变量与 Multi-Head Self-Attention (MSA)
模块的自适应层归一化(adaptive LayerNorm, adaLN)有关:
shift_msa
: 这个变量用于平移 LayerNorm
的输出,也就是在归一化的基础上加上一个偏置。它在调节 MSA 模块的激活输出时用作偏移量。scale_msa
: 这个变量用于缩放 LayerNorm
的输出,即对归一化的结果乘以一个比例因子。它控制了 MSA 模块中激活的放大或缩小程度。gate_msa
: 这个变量是作为一个门控(gate)信号,作用于 MSA 模块的输出上。它决定了 MSA 模块输出在累加到 x
之前的权重。如果 gate_msa
很小,那么这个输出会被抑制;如果 gate_msa
接近1,则输出会如常累加。
shift_mlp
和 scale_mlp
这两个变量与 Pointwise Feedforward (MLP)
模块的自适应层归一化(adaLN)有关,类似于 shift_msa
和 scale_msa
,但它们作用在 MLP 模块上:
shift_mlp
: 用于平移 LayerNorm
的输出,在 MLP 模块中作为偏移量。scale_mlp
: 用于缩放 LayerNorm
的输出,在 MLP 模块中控制激活的放大或缩小。gate_mlp
: 类似于 gate_msa
,但它控制的是 MLP 模块的输出。它决定了 MLP 模块输出在累加到 x
之前的权重。
在 adaLN_modulation(c)
中,c 经过一个 nn.Linear 层(即 nn.Linear(hidden_size, 6 * hidden_size, bias=True)),然后被 chunk(6, dim=1) 分成六个部分,分别得到 shift_msa、scale_msa、gate_msa、shift_mlp、scale_mlp 和 gate_mlp。
虽然这些变量来自于同一个线性层的输出,但由于 nn.Linear 层的权重在训练过程中是可学习的,并且是随机初始化的,因此这些权重会在训练过程中被更新为不同的值。
那么为什用 adaLN-Zero 来代替 Cross-Attention 呢?主要是因为计算资源。(DiT 原文提到 Cross-attention adds the most Gflops to the model, roughly a 15% overhead.)
什么是adaLN-Zero Block?
adaLN-Zero Block是一种改进版的adaLN(Adaptive Layer Normalization)模块,主要用于扩散模型(Diffusion Model)中。它的核心思想是通过初始化技巧和引入额外的缩放参数,来加速模型训练并提高生成样本的质量。
为什么引入adaLN-Zero Block?
adaLN-Zero Block的工作原理
与传统adaLN的区别
为什么有效?
最后也附上原文,便于对照理解。
Peebles, William, and Saining Xie. “Scalable diffusion models with transformers.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023. ↩︎
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。