当前位置:   article > 正文

扩散模型 (Diffusion Model) 简要介绍与源码分析

diffusion model

扩散模型 (Diffusion Model) 简要介绍与源码分析

前言

近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.

于是找到他们的开源产品体验了一把, 在线体验地址在 https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上: https://github.com/CompVis/stable-diffusion), 在搜索框中输入 “A dog flying in the sky” (一只狗在天空飞翔), 生成效果如下:

Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.

当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.

广而告之

可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.

另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.

总览

本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.

参考文章

在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:

扩散模型介绍

基本原理

Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.

具体来说, 前向阶段在原始图像 x 0 \mathbf{x}_0 x0 上逐步增加噪声, 每一步得到的图像 x t \mathbf{x}_t xt 只和上一步的结果 x t − 1 \mathbf{x}_{t - 1} xt1 相关, 直至第 T T T 步的图像 x T \mathbf{x}_T xT 变为纯高斯噪声. 前向阶段图示如下:

而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声 x T \mathbf{x}_T xT, 通过逐步去噪, 直至最终将原图像 x 0 \mathbf{x}_0 x0 给恢复出来, 逆向阶段图示如下:

模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式,

前向阶段

由于前向过程中图像 x t \mathbf{x}_t xt 只和上一时刻的 x t − 1 \mathbf{x}_{t - 1} xt1 有关, 该过程可以视为马尔科夫过程, 满足:

q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) , q(x1:Tx0)=Tt=1q(xtxt1)q(xtxt1)=N(xt;1βtxt1,βtI),

q(x1:Tx0)q(xtxt1)=t=1Tq(xtxt1)=N(xt;1βtxt1,βtI),
q(x1:Tx0)q(xtxt1)=t=1Tq(xtxt1)=N(xt;1βt xt1,βtI),

其中 β t ∈ ( 0 , 1 ) \beta_t\in(0, 1) βt(0,1) 为高斯分布的方差超参, 并满足 β 1 < β 2 < … < β T \beta_1 < \beta_2 < \ldots < \beta_T β1<β2<<βT. 另外公式 (2) 中为何均值 x t − 1 x_{t-1} xt1 前乘上系数 1 − β t x t − 1 \sqrt{1-\beta_t} x_{t-1} 1βt xt1 的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过 重参数技巧 采样得到 x t x_t xt.

重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布 z ∼ N ( z ; μ , σ 2 I ) z \sim \mathcal{N}(z; \mu, \sigma^2\mathbf{I}) zN(z;μ,σ2I) 中采样样本 z z z, 可以通过引入随机变量 ϵ ∼ N ( 0 , I ) \epsilon\sim\mathcal{N}(0, \mathbf{I}) ϵN(0,I), 使得 z = μ + σ ⊙ ϵ z = \mu + \sigma\odot\epsilon z=μ+σϵ, 此时 z z z 依旧具有随机性, 且服从高斯分布 N ( μ , σ 2 I ) \mathcal{N}(\mu, \sigma^2\mathbf{I}) N(μ,σ2I), 同时 μ \mu μ σ \sigma σ (通常由网络生成) 可导.

简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样 x t x_t xt 的方法, 即生成随机变量 ϵ t ∼ N ( 0 , I ) \epsilon_t\sim\mathcal{N}(0, \mathbf{I}) ϵtN(0,I),
然后令 α t = 1 − β t \alpha_t = 1 - \beta_t αt=1βt, 以及 α t ‾ = ∏ i = 1 T α t \overline{\alpha_t} = \prod_{i=1}^{T}\alpha_t αt=i=1Tαt, 从而可以得到:

x t = 1 − β t x t − 1 + β t ϵ 1  where     ϵ 1 , ϵ 2 , … ∼ N ( 0 , I ) ,    reparameter trick ; = a t x t − 1 + 1 − α t ϵ 1 = a t ( a t − 1 x t − 2 + 1 − α t − 1 ϵ 2 ) + 1 − α t ϵ 1 = a t a t − 1 x t − 2 + ( a t ( 1 − α t − 1 ) ϵ 2 + 1 − α t ϵ 1 ) = a t a t − 1 x t − 2 + 1 − α t α t − 1 ϵ ˉ 2  where  ϵ ˉ 2 ∼ N ( 0 , I ) ; = … = α ˉ t x 0 + 1 − α ˉ t ϵ ˉ t . xt=1βtxt1+βtϵ1 where ϵ1,ϵ2,N(0,I),reparameter trick;=atxt1+1αtϵ1=at(at1xt2+1αt1ϵ2)+1αtϵ1=atat1xt2+(at(1αt1)ϵ2+1αtϵ1)=atat1xt2+1αtαt1ˉϵ2 where ˉϵ2N(0,I);==ˉαtx0+1ˉαtˉϵt.

xt=1βtxt1+βtϵ1 where ϵ1,ϵ2,N(0,I),reparameter trick;=atxt1+1αtϵ1=at(at1xt2+1αt1ϵ2)+1αtϵ1=atat1xt2+(at(1αt1)ϵ2+1αtϵ1)=atat1xt2+1αtαt1ϵ¯2 where ϵ¯2N(0,I);==α¯tx0+1α¯tϵ¯t.(3-1)(3-2)
xt=1βt xt1+βtϵ1 where ϵ1,ϵ2,N(0,I),reparameter trick;=at xt1+1αt ϵ1=at (at1 xt2+1αt1 ϵ2)+1αt ϵ1=atat1 xt2+(at(1αt1) ϵ2+1αt ϵ1)=atat1 xt2+1αtαt1 ϵˉ2 where ϵˉ2N(0,I);==αˉt x0+1αˉt ϵˉt.(3-1)(3-2)

其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有 N ( 0 , σ 1 2 I ) + N ( 0 , σ 2 2 I ) ∼ N ( 0 , ( σ 1 2 + σ 2 2 ) I ) \mathcal{N}\left(0, \sigma_1^2\mathbf{I}\right) +\mathcal{N}\left(0,\sigma_2^2 \mathbf{I}\right)\sim\mathcal{N}\left(0, \left(\sigma_1^2 + \sigma_2^2\right)\mathbf{I}\right) N(0,σ12I)+N(0,σ22I)N(0,(σ12+σ22)I), 因此:

a t ( 1 − α t − 1 ) ϵ 2 ∼ N ( 0 , a t ( 1 − α t − 1 ) I ) 1 − α t ϵ 1 ∼ N ( 0 , ( 1 − α t ) I ) a t ( 1 − α t − 1 ) ϵ 2 + 1 − α t ϵ 1 ∼ N ( 0 , [ α t ( 1 − α t − 1 ) + ( 1 − α t ) ] I ) = N ( 0 , ( 1 − α t α t − 1 ) I ) . at(1αt1)ϵ2N(0,at(1αt1)I)1αtϵ1N(0,(1αt)I)at(1αt1)ϵ2+1αtϵ1N(0,[αt(1αt1)+(1αt)]I)=N(0,(1αtαt1)I).

at(1αt1)ϵ2N(0,at(1αt1)I)1αtϵ1N(0,(1αt)I)at(1αt1)ϵ2+1αtϵ1N(0,[αt(1αt1)+(1αt)]I)=N(0,(1αtαt1)I).
at(1αt1) ϵ2N(0,at(1αt1)I)1αt ϵ1N(0,(1αt)I)at(1αt1) ϵ2+1αt ϵ1N(0,[αt(1αt1)+(1αt)]I)=N(0,(1αtαt1)I).

注意公式 (3-2) 中 ϵ ˉ 2 ∼ N ( 0 , I ) \bar{\epsilon}_2 \sim \mathcal{N}(0, \mathbf{I}) ϵˉ2N(0,I), 因此还需乘上 1 − α t α t − 1 \sqrt{1-\alpha_t \alpha_{t-1}} 1αtαt1 . 从公式 (3) 可以看出

q ( x t ∣ x 0 ) = N ( x t ; a ˉ t x 0 , ( 1 − a ˉ t ) I ) q(xtx0)=N(xt;ˉatx0,(1ˉat)I)

q(xtx0)=N(xt;aˉt x0,(1aˉt)I)

注意由于 β t ∈ ( 0 , 1 ) \beta_t\in(0, 1) βt(0,1) β 1 < … < β T \beta_1 < \ldots < \beta_T β1<<βT, 而 α t = 1 − β t \alpha_t = 1 - \beta_t αt=1βt, 因此 α t ∈ ( 0 , 1 ) \alpha_t\in(0, 1) αt(0,1) 并且有 α 1 > … > α T \alpha_1 > \ldots>\alpha_T α1>>αT, 另外由于 α ˉ t = ∏ i = 1 T α t \bar{\alpha}_t=\prod_{i=1}^T\alpha_t αˉt=i=1Tαt, 因此当 T → ∞ T\rightarrow\infty T 时, α ˉ t → 0 \bar{\alpha}_t\rightarrow0 αˉt0 以及 ( 1 − a ˉ t ) → 1 (1-\bar{a}_t)\rightarrow 1 (1aˉt)1, 此时 x T ∼ N ( 0 , I ) x_T\sim\mathcal{N}(0, \mathbf{I}) xTN(0,I). 从这里的推导来看, 在公式 (2) 中的均值 x t − 1 x_{t-1} xt1 前乘上系数 1 − β t x t − 1 \sqrt{1-\beta_t} x_{t-1} 1βt xt1 会使得 x T x_{T} xT 最后收敛到标准高斯分布.

逆向阶段

前向阶段是加噪声的过程, 而逆向阶段则是将噪声去除, 如果能得到逆向过程的分布 q ( x t − 1 ∣ x t ) q\left(x_{t-1} \mid x_t\right) q(xt1xt), 那么通过输入高斯噪声 x T ∼ N ( 0 , I ) x_T\sim\mathcal{N}(0, \mathbf{I}) xTN(0,I), 我们将生成一个真实的样本. 注意到当 β t \beta_t βt 足够小时, q ( x t − 1 ∣ x t ) q\left(x_{t-1} \mid x_t\right) q(xt1xt) 也是高斯分布, 具体的证明在 ewrfcas 的知乎文章: 由浅入深了解Diffusion Model 推荐的论文中: On the theory of stochastic processes, with particular reference to applications. 我大致看了一下, 哈哈, 没太看明白, 不过想到这个不是我关注的重点, 因此 pass. 由于我们无法直接推断 q ( x t − 1 ∣ x t ) q\left(x_{t-1} \mid x_t\right) q(xt1xt), 因此我们将使用深度学习模型 p θ p_{\theta} pθ 去拟合分布 q ( x t − 1 ∣ x t ) q\left(x_{t-1} \mid x_t\right) q(xt1xt), 模型参数为 θ \theta θ:

p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) pθ(x0:T)=p(xT)Tt=1pθ(xt1xt)pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

pθ(x0:T)pθ(xt1xt)=p(xT)t=1Tpθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

注意到, 虽然我们无法直接求得 q ( x t − 1 ∣ x t ) q\left(x_{t-1} \mid x_t\right) q(xt1xt) (注意这里是 q q q 而不是模型 p θ p_{\theta} pθ), 但在知道 x 0 x_0 x0 的情况下, 可以通过贝叶斯公式得到 q ( x t − 1 ∣ x t , x 0 ) q\left(x_{t-1} \mid x_t, x_0\right) q(xt1xt,x0) 为:

q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ ( x t , x 0 ) , β ~ t I ) q(xt1xt,x0)=N(xt1;˜μ(xt,x0),˜βtI)

q(xt1xt,x0)=N(xt1;μ~(xt,x0),β~tI)

推导过程如下:

q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) ∝ exp ⁡ ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = exp ⁡ ( − 1 2 ( x t 2 − 2 α t x t x t − 1 + α t x t − 1 2 β t + x t − 1 2 − 2 α ˉ t − 1 x 0 x t − 1 + α ˉ t − 1 x 0 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = exp ⁡ ( − 1 2 ( ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 ⏟ x t − 1  方差  − ( 2 α t β t x t + 2 α ˉ t − 1 1 − α ˉ t − 1 x 0 ) x t − 1 ⏟ x t − 1  均值  + C ( x t , x 0 ) ⏟ 与  x t − 1  无关  ) ) q(xt1|xt,x0)=q(xt|xt1,x0)q(xt1|x0)q(xt|x0)exp(12((xtαtxt1)2βt+(xt1ˉαt1x0)21ˉαt1(xtˉαtx0)21ˉαt))=exp(12(x2t2αtxtxt1+αtx2t1βt+x2t12ˉαt1x0xt1+ˉαt1x201ˉαt1(xtˉαtx0)21ˉαt))=exp(12((αtβt+11ˉαt1)x2t1xt1 方差 (2αtβtxt+2ˉαt11ˉαt1x0)xt1xt1 均值 +C(xt,x0)与 xt1 无关 ))

q(xt1xt,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0)exp(21(βt(xtαt xt1)2+1αˉt1(xt1αˉt1 x0)21αˉt(xtαˉt x0)2))=exp(21(βtxt22αt xtxt1+αtxt12+1αˉt1xt122αˉt1 x0xt1+αˉt1x021αˉt(xtαˉt x0)2))=exp(21(xt1 方差  (βtαt+1αˉt11)xt12xt1 均值  (βt2αt xt+1αˉt12αˉt1 x0)xt1+ xt1 无关  C(xt,x0)))

上面推导过程中, 通过贝叶斯公式巧妙的将逆向过程转换为前向过程, 且最终得到的概率密度函数和高斯概率密度函数的指数部分 exp ⁡ ( − ( x − μ ) 2 2 σ 2 ) = exp ⁡ ( − 1 2 ( 1 σ 2 x 2 − 2 μ σ 2 x + μ 2 σ 2 ) ) \exp{\left(-\frac{\left(x - \mu\right)^2}{2\sigma^2}\right)} = \exp{\left(-\frac{1}{2}\left(\frac{1}{\sigma^2}x^2 - \frac{2\mu}{\sigma^2}x + \frac{\mu^2}{\sigma^2}\right)\right)} exp(2σ2(xμ)2)=exp(21(σ21x2σ22μx+σ2μ2)) 能对应, 即有:

β ~ t = 1 / ( α t β t + 1 1 − α ˉ t − 1 ) = 1 / ( α t − α ˉ t + β t β t ( 1 − α ˉ t − 1 ) ) = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t μ ~ t ( x t , x 0 ) = ( α t β t x t + α ˉ t − 1 1 − α ˉ t − 1 x 0 ) / ( α t β t + 1 1 − α ˉ t − 1 ) = ( α t β t x t + α ˉ t − 1 1 − α ˉ t − 1 x 0 ) 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 ˜βt=1/(αtβt+11ˉαt1)=1/(αtˉαt+βtβt(1ˉαt1))=1ˉαt11ˉαtβt˜μt(xt,x0)=(αtβtxt+ˉαt11ˉαt1x0)/(αtβt+11ˉαt1)=(αtβtxt+ˉαt11ˉαt1x0)1ˉαt11ˉαtβt=αt(1ˉαt1)1ˉαtxt+ˉαt1βt1ˉαtx0

β~tμ~t(xt,x0)=1/(βtαt+1αˉt11)=1/(βt(1αˉt1)αtαˉt+βt)=1αˉt1αˉt1βt=(βtαt xt+1αˉt1αˉt1 x0)/(βtαt+1αˉt11)=(βtαt xt+1αˉt1αˉt1 x0)1αˉt1αˉt1βt=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0

通过公式 (8) 和公式 (9), 我们能得到 q ( x t − 1 ∣ x t , x 0 ) q\left(x_{t-1} \mid x_t, x_0\right) q(xt1xt,x0) (见公式 (7)) 的分布. 此外由于公式 (3) 揭示的 x t x_t xt x 0 x_0 x0 之间的关系: x t = α ˉ t x 0 + 1 − α ˉ t ϵ ˉ t x_t =\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \bar{\epsilon}_t xt=αˉt x0+1αˉt ϵˉt, 可以得到

x 0 = 1 α ˉ t ( x t − 1 − α ˉ t ϵ t ) x0=1ˉαt(xt1ˉαtϵt)

x0=αˉt 1(xt1αˉt ϵt)

代入公式 (9) 中得到:

μ ~ t = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t 1 α ˉ t ( x t − 1 − α ˉ t ϵ t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) ˜μt=αt(1ˉαt1)1ˉαtxt+ˉαt1βt1ˉαt1ˉαt(xt1ˉαtϵt)=1αt(xt1αt1ˉαtϵt)

μ~t=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtαˉt 1(xt1αˉt ϵt)=αt 1(xt1αˉt 1αtϵt)

补充一下公式 (11) 的详细推导过程:

前面说到, 我们将使用深度学习模型 p θ p_{\theta} pθ 去拟合逆向过程的分布 q ( x t − 1 ∣ x t ) q\left(x_{t-1} \mid x_t\right) q(xt1xt), 由公式 (6) 知 p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta\left(x_{t-1} \mid x_t\right) =\mathcal{N}\left(x_{t-1} ; \mu_\theta\left(x_t, t\right), \Sigma_\theta\left(x_t, t\right)\right) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t)), 我们希望训练模型 μ θ ( x t , t ) \mu_\theta\left(x_t, t\right) μθ(xt,t) 以预估 μ ~ t = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) \tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_t \Big) μ~t=αt 1(xt1αˉt 1αtϵt). 由于 x t x_t xt 在训练阶段会作为输入, 因此它是已知的, 我们可以转而让模型去预估噪声 ϵ t \epsilon_t ϵt, 即令:

μ θ ( x t , t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) Thus  x t − 1 = N ( x t − 1 ; 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) , Σ θ ( x t , t ) ) μθ(xt,t)=1αt(xt1αt1ˉαtϵθ(xt,t))Thus xt1=N(xt1;1αt(xt1αt1ˉαtϵθ(xt,t)),Σθ(xt,t))

μθ(xt,t)Thus xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))=N(xt1;αt 1(xt1αˉt 1αtϵθ(xt,t)),Σθ(xt,t))

模型训练

前面谈到, 逆向阶段让模型去预估噪声 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t), 那么应该如何设计 Loss 函数 ? 我们的目标是在真实数据分布下, 最大化模型预测分布的对数似然, 即优化在 x 0 ∼ q ( x 0 ) x_0\sim q(x_0) x0q(x0) 下的 p θ ( x 0 ) p_\theta(x_0) pθ(x0) 交叉熵:

L = E q ( x 0 ) [ − log ⁡ p θ ( x 0 ) ] L=Eq(x0)[logpθ(x0)]

L=Eq(x0)[logpθ(x0)]

变分自动编码器 VAE 类似, 使用 Variational Lower Bound 来优化: − log ⁡ p θ ( x 0 ) -\log{p_\theta(x_0)} logpθ(x0) :

− log ⁡ p θ ( x 0 ) ≤ − log ⁡ p θ ( x 0 ) + D K L ( q ( x 1 : T ∣ x 0 ) ∥ p θ ( x 1 : T ∣ x 0 ) ) ; 注: 注意KL散度非负 = − log ⁡ p θ ( x 0 ) + E q ( x 1 : T ∣ x 0 ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) / p θ ( x 0 ) ] ;     where     p θ ( x 1 : T ∣ x 0 ) = p θ ( x 0 : T ) p θ ( x 0 ) = − log ⁡ p θ ( x 0 ) + E q ( x 1 : T ∣ x 0 ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) + log ⁡ p θ ( x 0 ) ⏟ 与q无关  ] = E q ( x 1 : T ∣ x 0 ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] . logpθ(x0)logpθ(x0)+DKL(q(x1:Tx0)pθ(x1:Tx0));注: 注意KL散度非负=logpθ(x0)+Eq(x1:Tx0)[logq(x1:Tx0)pθ(x0:T)/pθ(x0)]; where pθ(x1:Tx0)=pθ(x0:T)pθ(x0)=logpθ(x0)+Eq(x1:Tx0)[logq(x1:Tx0)pθ(x0:T)+logpθ(x0)与q无关 ]=Eq(x1:Tx0)[logq(x1:Tx0)pθ(x0:T)].

logpθ(x0)logpθ(x0)+DKL(q(x1:Tx0)pθ(x1:Tx0));注意KL散度非负=logpθ(x0)+Eq(x1:Tx0)[logpθ(x0:T)/pθ(x0)q(x1:Tx0)]; where pθ(x1:Tx0)=pθ(x0)pθ(x0:T)=logpθ(x0)+Eq(x1:Tx0)[logpθ(x0:T)q(x1:Tx0)+q无关  logpθ(x0)]=Eq(x1:Tx0)[logpθ(x0:T)q(x1:Tx0)].

对公式 (15) 左右两边取期望 E q ( x 0 ) \mathbb{E}_{q(x_0)} Eq(x0), 利用到重积分中的 Fubini 定理 可得:

L V L B = E q ( x 0 ) ( E q ( x 1 : T ∣ x 0 ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] ) = E q ( x 0 : T ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] ⏟ Fubini定理  ≥ E q ( x 0 ) [ − log ⁡ p θ ( x 0 ) ] \mathcal{L}_{V L B}=\underbrace{\mathbb{E}_{q\left(x_0\right)}\left(\mathbb{E}_{q\left(x_{1: T} \mid x_0\right)}\left[\log \frac{q\left(x_{1: T} \mid x_0\right)}{p_\theta\left(x_{0: T}\right)}\right]\right)=\mathbb{E}_{q\left(x_{0: T}\right)}\left[\log \frac{q\left(x_{1: T} \mid x_0\right)}{p_\theta\left(x_{0: T}\right)}\right]}_{\text {Fubini定理 }} \geq \mathbb{E}_{q\left(x_0\right)}\left[-\log p_\theta\left(x_0\right)\right] LVLB=Fubini定理  Eq(x0)(Eq(x1:Tx0)[logpθ(x0:T)q(x1:Tx0)])=Eq(x0:T)[logpθ(x0:T)q(x1:Tx0)]Eq(x0)[logpθ(x0)]

因此最小化 L V L B \mathcal{L}_{V L B} LVLB 就可以优化公式 (14) 中的目标函数. 之后对 L V L B \mathcal{L}_{V L B} LVLB 做进一步的推导, 这部分的详细推导见上面的参考文章, 最终的结论是:

L V L B = L T + L T − 1 + … + L 0 L T = D K L ( q ( x T ∣ x 0 ) ∣ ∣ p θ ( x T ) ) L t = D K L ( q ( x t ∣ x t + 1 , x 0 ) ∣ ∣ p θ ( x t ∣ x t + 1 ) ) ; 1 ≤ t ≤ T − 1 L 0 = − log ⁡ p θ ( x 0 ∣ x 1 ) LVLB=LT+LT1++L0LT=DKL(q(xT|x0)||pθ(xT))Lt=DKL(q(xt|xt+1,x0)||pθ(xt|xt+1));1tT1L0=logpθ(x0|x1)

LVLBLTLtL0=LT+LT1++L0=DKL(q(xTx0)∣∣pθ(xT))=DKL(q(xtxt+1,x0)∣∣pθ(xtxt+1));1tT1=logpθ(x0x1)

最终是优化两个高斯分布 q ( x t ∣ x t − 1 , x 0 ) = N ( x t − 1 ; μ ~ ( x t , x 0 ) , β ~ t I ) q(x_t|x_{t - 1}, x_0) = \mathcal{N}\left(x_{t-1} ; {\color{blue}{\tilde{\mu}}(x_t, x_0)}, {\color{red}{\tilde{\beta}_t} \mathbf{I}}\right) q(xtxt1,x0)=N(xt1;μ~(xt,x0),β~tI) (详见公式 (7)) 与 p θ ( x t ∣ x t + 1 ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ) p_{\theta}(x_t|x_{t+1}) = \mathcal{N}\left(x_{t-1} ; \mu_\theta\left(x_t, t\right), \Sigma_\theta\right) pθ(xtxt+1)=N(xt1;μθ(xt,t),Σθ) (详见公式(6), 此为模型预估的分布)之间的 KL 散度. 由于多元高斯分布的 KL 散度存在闭式解, 详见: Multivariate_normal_distributions, 从而可以得到:

L t = E x 0 , ϵ [ 1 2 ∥ Σ θ ( x t , t ) ∥ 2 2 ∥ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∥ 2 ] = E x 0 , ϵ [ 1 2 ∥ Σ θ ∥ 2 2 ∥ 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) − 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) ∥ 2 ] = E x 0 , ϵ [ ( 1 − α t ) 2 2 α t ( 1 − α ˉ t ) ∥ Σ θ ∥ 2 2 ∥ ϵ t − ϵ θ ( x t , t ) ∥ 2 ] ; 其中 ϵ t 为高斯噪声 , ϵ θ 为模型学习的噪声 = E x 0 , ϵ [ ( 1 − α t ) 2 2 α t ( 1 − α ˉ t ) ∥ Σ θ ∥ 2 2 ∥ ϵ t − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) ∥ 2 ] Lt=Ex0,ϵ[12Σθ(xt,t)22˜μt(xt,x0)μθ(xt,t)2]=Ex0,ϵ[12Σθ221αt(xt1αt1ˉαtϵt)1αt(xt1αt1ˉαtϵθ(xt,t))2]=Ex0,ϵ[(1αt)22αt(1ˉαt)Σθ22ϵtϵθ(xt,t)2];其中ϵt为高斯噪声,ϵθ为模型学习的噪声=Ex0,ϵ[(1αt)22αt(1ˉαt)Σθ22ϵtϵθ(ˉαtx0+1ˉαtϵt,t)2]

Lt=Ex0,ϵ[2∥Σθ(xt,t)221μ~t(xt,x0)μθ(xt,t)2]=Ex0,ϵ[2∥Σθ221αt 1(xt1αˉt 1αtϵt)αt 1(xt1αˉt 1αtϵθ(xt,t))2]=Ex0,ϵ[2αt(1αˉt)Σθ22(1αt)2ϵtϵθ(xt,t)2];其中ϵt为高斯噪声,ϵθ为模型学习的噪声=Ex0,ϵ[2αt(1αˉt)Σθ22(1αt)2ϵtϵθ(αˉt x0+1αˉt ϵt,t)2]

DDPM 将 Loss 简化为如下形式:

L t simple  = E x 0 , ϵ t [ ∥ ϵ t − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) ∥ 2 ] Lsimple t=Ex0,ϵt[ϵtϵθ(ˉαtx0+1ˉαtϵt,t)2]

Ltsimple =Ex0,ϵt[ ϵtϵθ(αˉt x0+1αˉt ϵt,t) 2]

因此 Diffusion 模型的目标函数即是学习高斯噪声 ϵ t \epsilon_t ϵt ϵ θ \epsilon_{\theta} ϵθ (来自模型输出) 之间的 MSE loss.

最终算法

最终 DDPM 的算法流程如下:

训练阶段重复如下步骤:

  • 从数据集中采样 x 0 x_0 x0
  • 随机选取 time step t t t
  • 生成高斯噪声 ϵ t ∈ N ( 0 , I ) \epsilon_t\in\mathcal{N}(0, \mathbf{I}) ϵtN(0,I)
  • 调用模型预估 ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) \epsilon_\theta\left(\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t, t\right) ϵθ(αˉt x0+1αˉt ϵt,t)
  • 计算噪声之间的 MSE Loss: ∥ ϵ t − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) ∥ 2 \left\|\epsilon_t-\epsilon_\theta\left(\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t, t\right)\right\|^2 ϵtϵθ(αˉt x0+1αˉt ϵt,t) 2, 并利用反向传播算法训练模型.

逆向阶段采用如下步骤进行采样:

  • 从高斯分布采样 x T x_T xT
  • 按照 T , … , 1 T, \ldots, 1 T,,1 的顺序进行迭代:
    • 如果 t = 1 t = 1 t=1, 令 z = 0 \mathbf{z} = {0} z=0; 如果 t > 1 t > 1 t>1, 从高斯分布中采样 z ∼ N ( 0 , I ) \mathbf{z}\sim\mathcal{N}(0, \mathbf{I}) zN(0,I)
    • 利用公式 (12) 学习出均值 μ θ ( x t , t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) \mu_\theta(x_t, t) = \color{cyan}{\frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \Big)} μθ(xt,t)=αt 1(xt1αˉt 1αtϵθ(xt,t)), 并利用公式 (8) 计算均方差 σ t = β ~ t = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t \sigma_t = \sqrt{\tilde{\beta}_t} = \sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t} σt=β~t =1αˉt1αˉt1βt
    • 通过重参数技巧采样 x t − 1 = μ θ ( x t , t ) + σ t z x_{t - 1} = \mu_\theta(x_t, t) + \sigma_t\mathbf{z} xt1=μθ(xt,t)+σtz
  • 经过以上过程的迭代, 最终恢复 x 0 x_0 x0.

源码分析

DDPM 文章以及代码的相关信息如下:

本文以分析 Tensorflow 源码为主, Pytorch 版本的代码和 Tensorflow 版本的实现逻辑大体不差的, 变量名字啥的都类似, 阅读起来不会有啥门槛. Tensorlow 源码对 Diffusion 模型的实现位于 diffusion_utils_2.py, 模型本身的分析以该文件为主.

训练阶段

以 CIFAR 数据集为例.

run_cifar.py 中进行前向传播计算 Loss:

  • 第 6 行随机选出 t ∼ Uniform ( { 1 , … , T } ) t\sim\text{Uniform}(\{1, \ldots, T\}) tUniform({1,,T})
  • 第 7 行 training_losses 定义在 GaussianDiffusion2 中, 计算噪声间的 MSE Loss.

进入 GaussianDiffusion2 中, 看到初始化函数中定义了诸多变量, 我在注释中使用公式的方式进行了说明:

下面进入到 training_losses 函数中:

  • 第 19 行: self.model_mean_type 默认是 eps, 模型学习的是噪声, 因此 target 是第 6 行定义的 noise, 即 ϵ t \epsilon_t ϵt
  • 第 9 行: 调用 self.q_sample 计算 x t x_t xt, 即公式 (3) x t = α ˉ t x 0 + 1 − α ˉ t ϵ t x_t =\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t xt=αˉt x0+1αˉt ϵt
  • 第 21 行: denoise_fn 是定义在 unet.py 中的 UNet 模型, 只需知道它的输入和输出大小相同; 结合第 9 行得到的 x t x_t xt, 得到模型预估的噪声: ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) \epsilon_\theta\left(\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t, t\right) ϵθ(αˉt x0+1αˉt ϵt,t)
  • 第 23 行: 计算两个噪声之间的 MSE: ∥ ϵ t − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) ∥ 2 \left\|\epsilon_t-\epsilon_\theta\left(\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \epsilon_t, t\right)\right\|^2 ϵtϵθ(αˉt x0+1αˉt ϵt,t) 2, 并利用反向传播算法训练模型

上面第 9 行定义的 self.q_sample 详情如下:

  • 第 13 行的 q_sample 已经介绍过, 不多说.
  • 第 2 行的 _extract 在代码中经常被使用到, 看到它只需知道它是用来提取系数的即可. 引入输入是一个 Batch, 里面的每个样本都会随机采样一个 time step t t t, 因此需要使用 tf.gather 来将 α t ˉ \bar{\alpha_t} αtˉ 之类选出来, 然后将系数 reshape 为 [B, 1, 1, ....] 的形式, 目的是为了利用 broadcasting 机制和 x t x_t xt 这个 Tensor 相乘.

前向的训练阶段代码实现非常简单, 下面看逆向阶段

逆向阶段

逆向阶段代码定义在 GaussianDiffusion2 中:

  • 第 5 行生成高斯噪声 x T x_T xT, 然后对其不断去噪直至恢复原始图像
  • 第 11 行的 self.p_sample 就是公式 (6) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta\left(x_{t-1} \mid x_t\right) =\mathcal{N}\left(x_{t-1} ; \mu_\theta\left(x_t, t\right), \Sigma_\theta\left(x_t, t\right)\right) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t)) 的过程, 使用模型来预估 μ θ ( x t , t ) \mu_\theta\left(x_t, t\right) μθ(xt,t) 以及 Σ θ ( x t , t ) \Sigma_\theta\left(x_t, t\right) Σθ(xt,t)
  • 第 12 行的 denoise_fn 在前面说过, 是定义在 unet.py 中的 UNet 模型; img_ 表示 x t x_t xt.
  • 第 13 行的 noise_fn 则默认是 tf.random_normal, 用于生成高斯噪声.

进入 p_sample 函数:

  • 第 7 行调用 self.p_mean_variance 生成 μ θ ( x t , t ) \mu_\theta\left(x_t, t\right) μθ(xt,t) 以及 log ⁡ ( Σ θ ( x t , t ) ) \log\left(\Sigma_\theta\left(x_t, t\right)\right) log(Σθ(xt,t)), 其中 Σ θ ( x t , t ) \Sigma_\theta\left(x_t, t\right) Σθ(xt,t) 通过计算 β ~ t \tilde{\beta}_t β~t 得到.
  • 第 11 行从高斯分布中采样 z \mathbf{z} z
  • 第 18 行通过重参数技巧采样 x t − 1 = μ θ ( x t , t ) + σ t z x_{t - 1} = \mu_\theta(x_t, t) + \sigma_t\mathbf{z} xt1=μθ(xt,t)+σtz, 其中 σ t = β ~ t \sigma_t = \sqrt{\tilde{\beta}_t} σt=β~t

进入 self.p_mean_variance 函数:

  • 第 6 行调用模型 denoise_fn, 通过输入 x t x_t xt, 输出得到噪声 ϵ t \epsilon_t ϵt
  • 第 19 行 self.model_var_type 默认为 fixedlarge, 但我当时看 fixedsmall 比较爽, 因此 model_variancemodel_log_variance 分别为 β ~ t = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t β~t=1αˉt1αˉt1βt (见公式 8), 以及 log ⁡ β ~ t \log\tilde{\beta}_t logβ~t
  • 第 29 行调用 self._predict_xstart_from_eps 函数, 利用公式 (10) 得到 x 0 = 1 α ˉ t ( x t − 1 − α ˉ t ϵ t ) x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1 - \bar{\alpha}_t}\epsilon_t) x0=αˉt 1(xt1αˉt ϵt)
  • 第 30 行调用 self.q_posterior_mean_variance 通过公式 (9) 得到 μ θ ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \mu_\theta(x_t, x_0) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} x_0 μθ(xt,x0)=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0

self._predict_xstart_from_eps 函数相亲如下:

  • 该函数计算 x 0 = 1 α ˉ t ( x t − 1 − α ˉ t ϵ t ) x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(x_t - \sqrt{1 - \bar{\alpha}_t}\epsilon_t) x0=αˉt 1(xt1αˉt ϵt)

self.q_posterior_mean_variance 函数详情如下:

  • 相关说明见注释, 另外发现对于 μ θ ( x t , x 0 ) \mu_\theta(x_t, x_0) μθ(xt,x0) 的计算使用的是公式 (9) μ θ ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \mu_\theta(x_t, x_0) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} x_0 μθ(xt,x0)=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0 而不是进一步推导后的公式 (11) μ θ ( x t , x 0 ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) \mu_\theta(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_t \Big) μθ(xt,x0)=αt 1(xt1αˉt 1αtϵt).

总结

写文章真的挺累的, 好处是, 我发现写之前我以为理解了, 但写的过程中又发现有些地方理解的不对. 写完后才终于把逻辑理顺.

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号