赞
踩
1. 背景知识
1.1 Classifier Guidance 和 Classifier-free Guidance
2. 算法详解
1. 模块化
2. DiT模块
2.1 上下文条件(In-context conditioning)
2.2 交叉注意力块(Cross-Attention)
2.3 自适应层归一化块(Adaptive Layer Normalization,AdaLN)
2.4 adaZero-Block
3. 总结
在Sora[1]的技术报告中,作者指出Sora是一个Diffusion Transformer。这个Diffusion Transformer便是我们这里将要介绍的DiT[2]。相较于我们之前介绍的LDM[3],DiTs也是作用在潜空间,它最大的改进是将U-Net的CNN替换为了Transformer。同时DiT是一个可扩展的架构,而且样本质量和网络复杂度存在这强烈的相关性。
与DiTs最密切的算法是LDM,LDM最大的特点是将DDPM[4]的计算空间从图像空间改到了潜空间。而图像空间和潜空间的互相转换则通过VQ-VAE[5]的编码器和解码器。LDM采用了一个CNN和交叉注意力的混合结构,其中CNN用于对图像进行编码,交叉注意力用于将条件特征融入到模型中。而DiT则是将LDM的CNN完全替换为了Transformer。
Classifier Guidance是OpenAI在《Diffusion models beat gans on image synthesis》[6]中提出的思想,它使得扩散模型可以按照指定的类生成图像。Classifier Guidance可以通过Score function来解释,我们可以使用贝叶斯定理对条件生成概率进行分解,如式(1)。从中可以看出Classifier Guidance的条件生成只需要添加一个额外的Classifier梯度即可。(1)∇��log�(��∣�)=∇��log(�(��)�(�∣��)�(�))=∇��log�(��)+∇��log�(�∣��)−∇��log�(�)=∇��log�(��)⏟unconditional score +∇��log�(�∣��)⏟classifier gradient 我们可以添加一个权重项 � 来调整来灵活的控制unconditional score和classifier gradient的占比,如式(2)。
(2)∇��log�(��∣�)=∇��log�(�∣��)+�∇��log�(��)
从式(1)中我们也可以看出Classifer Guidance的几个问题,首先因为需要训练Classifier梯度项,这相当于要额外训练一个根据噪声得到类别标签的分类器,显然是一个非常困难的任务。此外这个分类器的结果反映到了生成梯度上,无疑会对生成效果产生一定程度的影响。
为了解决这个问题,Google提出了Classifier-free Guidance方案[7]。Classifier-free guidance的核心是通过一个隐式分类器来代替显式分类器,使得生成过程不再依赖这个显式的分类器,从而解决了Classifier Guidance的这几个问题。具体来讲,我们对式(1)进行移项,可得:(3)∇��log�(�∣��)=∇��log�(��∣�)−∇��log�(��)将式(3)代入到式(2)中,我们有(4)∇��log�(��∣�)=∇��log�(��)+�(∇��log�(��∣�)−∇log�(��))=∇��log�(��)+�∇log�(��∣�)−�∇��log�(��)=�∇��log�(��∣�)⏟conditional score +(1−�)∇��log�(��)⏟unconditional score 根据式(4),我们的分类器由conditional score和unconditional score两部分组成。在训练时,我们可以通过一个对标签的Dropout来将标签以一定概率置空,从而实现了两个score在同一个模型中的训练。
和LDM一样,DiT也是一个作用在潜空间上的模型,因此它也采用了一个VQ-VAE将图像编码到潜空间。这里我们主要介绍DiT在潜空间上的扩散过程做的改进,它的结构如图1所示,跟论文一样,我们也是按照DiTs的前向顺序介绍这个图。DiT的具体实现见FAIR的开源代码[8],下面我们结合代码来具体介绍它们。
图1:DiTs的网络结构
首先,我们看一下DiT的forward函数的实现。
- class DiT(nn.Module):
- """
- Diffusion model with a Transformer backbone.
- """
- def __init__(
- self,
- input_size=32,
- patch_size=2,
- in_channels=4,
- hidden_size=1152,
- depth=28,
- num_heads=16,
- mlp_ratio=4.0,
- class_dropout_prob=0.1,
- num_classes=1000,
- learn_sigma=True,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels * 2 if learn_sigma else in_channels
- self.patch_size = patch_size
- self.num_heads = num_heads
-
- self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
- num_patches = self.x_embedder.num_patches
- # Will use fixed sin-cos embedding:
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
-
- self.blocks = nn.ModuleList([
- DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
- ])
- self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
- self.initialize_weights()
-
- def forward(self, x, t, y):
- """
- Forward pass of DiT.
- x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
- t: (N,) tensor of diffusion timesteps
- y: (N,) tensor of class labels
- """
- x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
- t = self.t_embedder(t) # (N, D)
- y = self.y_embedder(y, self.training) # (N, D)
- c = t + y # (N, D)
- for block in self.blocks:
- x = block(x, c) # (N, T, D)
- x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
- x = self.unpatchify(x) # (N, out_channels, H, W)
- return x
x_embedder
是import自timm.models.vision_transformer
,DiT中将其叫做模块化(Patchify);DiTBlock
类和FinalLayer
类,接下来我也回详细介绍。unpatchify
是将一维序列还原为二维潜空间。模块化(Patchify)的作用是将VAE编码的二维特征转化为一维序列。这里有两个细节:
图2:DiT的模块化部分
模块化的实现继承自timm.models.vision_transformer
的PatchEmbed
函数。
from timm.models.vision_transformer import PatchEmbed
DiT模块有两个作用,一个是对特征进行加工,另一个是融合图像的特征和不同模态的条件特征。DiT中探索了四个不同的模块:
如图1.(d)所示,基于上下文条件的DiT直接将条件特征附加到输入序列中,这个操作类似于在输入序列中添加了一个[CLS] token。DiT的条件编码是通过LabelEmbedder
类实现的,具体实现见下面代码片段。
- class LabelEmbedder(nn.Module):
- """
- Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
- """
- def __init__(self, num_classes, hidden_size, dropout_prob):
- super().__init__()
- use_cfg_embedding = dropout_prob > 0
- self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
- self.num_classes = num_classes
- self.dropout_prob = dropout_prob
-
- def token_drop(self, labels, force_drop_ids=None):
- """
- Drops labels to enable classifier-free guidance.
- """
- if force_drop_ids is None:
- drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
- else:
- drop_ids = force_drop_ids == 1
- labels = torch.where(drop_ids, self.num_classes, labels)
- return labels
-
- def forward(self, labels, train, force_drop_ids=None):
- use_dropout = self.dropout_prob > 0
- if (train and use_dropout) or (force_drop_ids is not None):
- labels = self.token_drop(labels, force_drop_ids)
- embeddings = self.embedding_table(labels)
- return embeddings
从上面的代码中我们可以看出,LabelEmbdder
的核心计算是通过一个embedding层对类别标签进行编码。注意DiT对标签进行了dropoout。如第1.1节介绍的,label dropout的作用是为了classifier-free guidance。
如图1.(c)所示,我们将时间片特征t和条件特征c拼成一个长度为2的序列(图1.(a))。然后将这个序列输入到一个多头交叉注意力模块中和图像特征进行融合。关于DiT交叉注意力的具体实现参照我的LDM一文。
DiT在模型中尝试了AdaLN[9],AdaLN的核心思想是使用模型中的一些信息学习 � 和 � 两个归一化参数。DiT是使用时间片特征 � 和条件特征 � 相加后的结果计算这两个参数(也就是第一个代码片段中的变量c
)。此外,DiT在每个残差连接之后还接了一个回归缩放参数 � ,它同样是由变量c
计算得到。接下来我们根据下面的代码片段详细介绍DiT的具体结构。
- from timm.models.vision_transformer import Attention, Mlp
- class DiTBlock(nn.Module):
- """
- A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
- """
- def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
- super().__init__()
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- approx_gelu = lambda: nn.GELU(approximate="tanh")
- self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
- self.adaLN_modulation = nn.Sequential(
- nn.SiLU(),
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
- x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
- x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
首先我们观察forword
函数的第1行,它使用adaLN_modulation
计算了6个变量shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
,这6个变量分别对应了多头自注意力的LN的归一化参数与缩放参数(图1.(b)的 �1 , �1 , �1 )以及MLP的LN的归一化参数与缩放参数(图1.(b)的 �2 , �2 , �2 )。
forward函数的第二行是计算多头自注意力以及它的LN,它首先计算的是modulate
函数,实现方式如下面代码片段,即相当于使用学习好的\beta和\gamma对LN进行归一化。接下来再计算的注意力模块,计算方式和Transformer相同。最后在通过乘以gate_msa
对注意力计算的结果进行缩放。
- def modulate(x, shift, scale):
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
接下来forward是计算MLP部分,除了把attn
函数换为mlp
函数外,它和第二行基本相同,这里不再赘述。
当对特征加工完之后,我们需要使用FinalLayer
模块来将特征还原为与输入相同的尺寸。它是由一个AdaLN和一个线性层组成,具体实现见下面代码片段。
- class FinalLayer(nn.Module):
- """
- The final layer of DiT.
- """
- def __init__(self, hidden_size, patch_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.SiLU(),
- nn.Linear(hidden_size, 2 * hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
之前有研究表明使用0初始化网络中的某些参数可以加速模型的训练。例如我们可以将残差网络的残差部分初始化为0,这样初始化后的残差块相当于一个单位映射,可以直接将上一层的特征透传给下一层。我们也可以将BN的归一化因子 � 初始化为0来加速模型的训练[10]。DiT对模型参数的初始化都是在initialize_weights
函数中实现的,它的作用是对DiT中的变量进行初始化,我们具体看一下这个函数。
- def initialize_weights(self):
- # Initialize transformer layers:
- def _basic_init(module):
- if isinstance(module, nn.Linear):
- torch.nn.init.xavier_uniform_(module.weight)
- if module.bias is not None:
- nn.init.constant_(module.bias, 0)
- self.apply(_basic_init)
-
- # Initialize (and freeze) pos_embed by sin-cos embedding:
- pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
- self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # Zero-out adaLN modulation layers in DiT blocks:
- for block in self.blocks:
- nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
作者对上面四种模块进行了对照实验,并使用了FID(Fréchet inception distance)指标对四个模块进行了效果评估。FID是计算真实图像和相似图像之间距离的的一种度量方式。他根据Inception v3分类模型计算得到的。分数越低则代表两组图像越相似,FID在最佳的情况下值是0,表示两组图完全相同。从实验结果我们可以看出adaLN-Zero还是有比较显著的优势的。
图3:DiT在四个模块上的对照实验
DiT最大的创新点是将Transformer引入到了扩散模型中,并完全抛弃了CNN。但是DiT并不是第一个引入Transformer的,例如之前的U-ViT[11],UniDiffuser[12]等都尝试了将Transformer引入到扩散模型中。至于对效果提升同样非常有帮助的adaLN,zero-初始化,classifier-free guidance等则是已有的工作了。DiT引入条件信息还是仅仅局限在样本类别,接下来我们有必要学习一些引入文本序列作为条件的生成模型了。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。