当前位置:   article > 正文

diffusion model(四)文生图diffusion model(classifier-free guided)_classifier-free diffusion model

classifier-free diffusion model

系列阅读

文生图diffusion model(classifier-free guided)

背景

在classifier-guided这篇博客我们提到对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,classifier-guided通过额外训练一个分类器来不断矫正每一个时间步的生成图片,最终实现特定类别图片的生成。

Classifier-free的核心思路是:我们无需训练额外的分类器,直接训练带类别信息的噪声预测模型来实现特定类别图片的生成,即 ϵ θ ( x t , t ) → ϵ ^ θ ( x t , y , t ) \epsilon_{\theta}(x_t, t) \rightarrow \hat{\epsilon}_{\theta}(x_t, y, t) ϵθ(xt,t)ϵ^θ(xt,y,t)。从而简化整体的pipeline。

此外,classifier-free方法不局限于类别信息的融入,它还能实现将语义信息融入到diffusion model中,实现更为灵活的文生图。这用classifier-guide是很难做到的。目前的很多工作如DALLE,Stable Diffusion, Imagen等都是Classifier-free形式。如:

在这里插入图片描述

下面我们来看他是怎么做的吧!

方法大意

classifier-free diffusion的实现非常简单。下面对比普通的diffusion model,classifier-guided与classifier-free三种方式的差异。

模型训练目标实现功能训练数据
DM (DDPM, DDIM) ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)从服从高斯分布的噪声中生成图片图片
classifier-guided DM ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)和分类器 p ( y ∣ x t ) p(y|x_t) p(yxt)从服从高斯分布的噪声中生成特定类别的图片DM:图片 分类器:图片-标签对
classifier-free DM ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t), ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)从服从高斯分布的噪声中生成符合文本描述的图片图片-文本对
  • 对于训练 ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)来估计 x t x_t xt在时间 t t t上添加的噪声,再根据采样公式推出 x t − 1 x_{t-1} xt1,从而实现图片生成。训练数据只需要准备图片即可。
  • 对于classifier-guided DM是在普通DM的基础上,额外再训练一个Classifier来获得当前时间步生成的图片类别概率分布,从而实现特定类别的图片生成。
  • 对于classifier-free DM将类别信息(或语义信息)集成到diffusion model的训练过程中,训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t) ϵ θ ( x t , y = ∅ , t ) ( 即 ϵ θ ( x t , t ) ) \epsilon_{\theta}(x_t, y=\empty,t)(\text{即}\epsilon_{\theta}(x_t,t)) ϵθ(xt,y=,t)(ϵθ(xt,t))。训练的时候也会加入无类别信息(或语义信息)的图片进行训练。

回答3个问题深入理解classifier-free DM

  1. 模型如何融入类别信息(或语义信息)
  2. 如何训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t) ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=,t)
  3. 如何进行采样生成

模型如何融入类别信息(或语义信息)

采用交叉注意力机制融入

我们知道,深度学习模型推理的本质可以理解为一系列的数值计算,因此将类别信息(或语义信息)融入到模型中需要预先将其转化为数值。转化的方法有很多,如可以用一个embedding layer。也可以用NLP模型,如Bert、T5、CLIP的text encoder等将类别信息(或语义信息)转化为数值向量,一般称为text embedding。随后需要将text embedding和原本模型中的image representation进行融合。最为常见且有效的方法是用交叉注意力机制CrossAttention。具体来说就是将text embedding作为注意力机制中的keyvalue,原始的图片表征作为query。大家熟知的Stable Diffusion用的就是这个融入方法。交叉注意力机制融入语义信息的本质是spatial-wise attention。

在这里插入图片描述

class SpatialCrossAttention(nn.Module):
  
    def __init__(self, dim, context_dim, heads=4, dim_head=32) -> None:
        super(SpatialCrossAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.proj_in = nn.Conv2d(dim, context_dim, kernel_size=1, stride=1, padding=0)
        self.to_q = nn.Linear(context_dim, hidden_dim, bias=False)
        self.to_k = nn.Linear(context_dim, hidden_dim, bias=False)
        self.to_v = nn.Linear(context_dim, hidden_dim, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, context=None):
        
        x_q = self.proj_in(x) 
        b, c, h, w = x_q.shape
        x_q = rearrange(x_q, "b c h w -> b (h w) c")
        if context is None:
            context = x_q
        if context.ndim == 2:
            context = rearrange(context, "b c -> b () c")
        q = self.to_q(x_q)
        k = self.to_k(context)
        v = self.to_v(context)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        
        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=self.heads)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w)
        out = self.to_out(out)
        return out 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
基于channel-wise attention融入

该融入方法与time-embedding的融入方法相同,在时间中往往会预先和time-embedding进行融合,再融入到图片特征中,伪代码如下:

# mixture time-embedding and label embedding
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
    assert y.shape == (x.shape[0],)
    emb = emb + self.label_emb(y)
while len(emb_out.shape) < len(h.shape):
    emb_out = emb_out[..., None]
emb_out = self.emb_layers(emb).type(h.dtype)  # h is image feature
scale, shift = th.chunk(emb_out, 2, dim=1)  # One half of the embedding is used for scaling and the other half for offset
h = h * (1 + scale) + shift  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

基于channel-wise的融入粒度没有CrossAttention细。一般适用类别数量有限的特征融入,如时间embedding,类别embedding。而语义信息的融入更推荐上面CrossAttention的方法。

在这里插入图片描述

如何训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t) ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\emptyset,t) ϵθ(xt,y=,t)

ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t)的训练需要图文对,但互联网上具备文本描述的图片只是浩如烟海的图片海洋中的一小部分。仅用具备图文对数据训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t)将会大大束缚DM的生成多样性。另外,为了使得模型更好的捕获图文的联系 ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=,t)的数据不宜过多,否则模型生成结果的保真度会降低。反之,若 ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=,t)数据过少,将会影响生成结果的多样性。需要根据实际的场景进行调整。

有两个实践中的trick需要注意

  • 在实践中,为了统一 y = ∅ y=\empty y= y ≠ ∅ y \neq \empty y=两种情形,通常会给定一个 y = ∅ y=\empty y=的embedding(可以随机初始化,也可以人为给定),来统一两种情形的建模。
  • 即使所有的数据都有图片对也没有关系,只需在每一个batch中随机将某些数据的标签编码替换为 y = ∅ y=\empty y=的embedding即可。另外

如何进行采样生成

classifier-free diffusion的采样生成过程与前面介绍的DDPM,DDIM类似。唯一有所区别的是将原本的 ϵ ( x t , t ) \epsilon(x_t, t) ϵ(xt,t)用下式代替。
ϵ ^ θ ( x t , y , t ) = ϵ θ ( x t , y = ∅ , t ) + s [ ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ]

(1)ϵ^θ(xt,y,t)=ϵθ(xt,y=\empty,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=\empty,t)]
ϵ^θ(xt,y,t)=ϵθ(xt,y=,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=,t)](1)

下面给出详细的推导过程:

首先根据贝叶斯公式有
p ( y ∣ x t ) = p ( x t ∣ y ) p ( y ) ⏞ 先验分布 p ( x t ) ⇒ p ( y ∣ x t ) ∝ p ( x t ∣ y ) / p ( x t ) ⇒ 取对数 log ⁡ p ( y ∣ x t ) = log ⁡ p ( x t ∣ y ) − log ⁡ p ( x t ) ⇒ 对 x t 求导 ∇ x t log ⁡ p ( y ∣ x t ) = ∇ x t log ⁡ p ( x t ∣ y ) − ∇ x t log ⁡ p ( x t ) ⇒ 根据score function ∇ x t log ⁡ p θ ( x t ) = − 1 1 − α ‾ t ϵ θ ( x t ) ∇ x t log ⁡ p ( y ∣ x t ) = − 1 1 − α ‾ t ( ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ) (2)

p(y|xt)=p(xt|y)p(y)先验分布p(xt)p(y|xt)p(xt|y)/p(xt)logp(y|xt)=logp(xt|y)logp(xt)xtxtlogp(y|xt)=xtlogp(xt|y)xtlogp(xt)根据score functionxtlogpθ(xt)=11α¯tϵθ(xt)xtlogp(y|xt)=11α¯t(ϵθ(xt,y,t)ϵθ(xt,y=\empty,t))
\tag{2} p(yxt)p(yxt)取对数logp(yxt)xt求导xtlogp(yxt)根据score functionxtlogpθ(xt)=1αt 1ϵθ(xt)xtlogp(yxt)=p(xt)p(xty)p(y) 先验分布p(xty)/p(xt)=logp(xty)logp(xt)=xtlogp(xty)xtlogp(xt)=1αt 1(ϵθ(xt,y,t)ϵθ(xt,y=,t))(2)
当我们得到 ∇ x t log ⁡ p ( y ∣ x t ) \nabla_{x_t}\log{p (y| x_t)} xtlogp(yxt),参考classifier-guided的式(17)
ϵ ^ ( x t ∣ y ) ⏟ 本文中的 ϵ ^ θ ( x t , y , t ) : = ϵ θ ( x t ) ⏟ 本文中的 ϵ θ ( x t , y = ∅ , t ) − s 1 − α ‾ t ∇ x t log ⁡ p ϕ ( y ∣ x t ) (3) \underbrace{\hat{\epsilon}(x_t|y)}_{\text{本文中的}\hat{\epsilon}_{\theta}(x_t, y, t)} := \underbrace{\epsilon_\theta(x_t)}_{\text{本文中的}\epsilon_{\theta}(x_t, y=\empty, t)} - s\sqrt{1 - \overline{\alpha}_t}\nabla_{x_t} \log{p_\phi(y|x_t)} \tag{3} 本文中的ϵ^θ(xt,y,t) ϵ^(xty):=本文中的ϵθ(xt,y=,t) ϵθ(xt)s1αt xtlogpϕ(yxt)(3)
可得
ϵ ^ θ ( x t , y , t ) = ϵ θ ( x t , y = ∅ , t ) + s [ ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ]
(4)ϵ^θ(xt,y,t)=ϵθ(xt,y=\empty,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=\empty,t)]
ϵ^θ(xt,y,t)=ϵθ(xt,y=,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=,t)](4)

后面的采样过程与之前的方式一致。

结语

本文详细介绍了classifier-free的提出背景与具体实现方案。它是后续一系列如stable diffusion,DALLE等文生图工作的基石。

参考文献

[1]: Classifier-Free Diffusion Guidance
[2]: GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/633535
推荐阅读
相关标签
  

闽ICP备14008679号