赞
踩
推荐 原创
珍妮的选择2023-03-15 22:09:01博主文章分类:机器学习©著作权
文章标签扩散模型DDPMStable-Diffusion深度学习计算机视觉文章分类计算机视觉人工智能yyds干货盘点阅读数1394
近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.
于是找到他们的开源产品体验了一把, 在线体验地址在 https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上: https://github.com/CompVis/stable-diffusion), 在搜索框中输入 “A dog flying in the sky” (一只狗在天空飞翔), 生成效果如下:
Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.
当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.
可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.
另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.
本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.
在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:
Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.
具体来说, 前向阶段在原始图像 �0x0 上逐步增加噪声, 每一步得到的图像 ��xt 只和上一步的结果 ��−1xt−1 相关, 直至第 �T 步的图像 ��xT 变为纯高斯噪声. 前向阶段图示如下:
而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声 ��xT, 通过逐步去噪, 直至最终将原图像 �0x0 给恢复出来, 逆向阶段图示如下:
模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式,
由于前向过程中图像 ��xt 只和上一时刻的 ��−1xt−1 有关, 该过程可以视为马尔科夫过程, 满足:
�(�1:�∣�0)=∏�=1��(��∣��−1)�(��∣��−1)=�(��;1−����−1,���),q(x1:T∣x0)q(xt∣xt−1)=t=1∏Tq(xt∣xt−1)=N(xt;1−βtxt−1,βtI),
其中 ��∈(0,1)βt∈(0,1) 为高斯分布的方差超参, 并满足 �1<�2<…<��β1<β2<…<βT. 另外公式 (2) 中为何均值 ��−1xt−1 前乘上系数 1−����−11−βtxt−1 的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过 重参数技巧 采样得到 ��xt.
重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布 �∼�(�;�,�2�)z∼N(z;μ,σ2I) 中采样样本 �z, 可以通过引入随机变量 �∼�(0,�)ϵ∼N(0,I), 使得 �=�+�⊙�z=μ+σ⊙ϵ, 此时 �z 依旧具有随机性, 且服从高斯分布 �(�,�2�)N(μ,σ2I), 同时 �μ 与 �σ (通常由网络生成) 可导.
简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样 ��xt 的方法, 即生成随机变量 ��∼�(0,�)ϵt∼N(0,I),
然后令 ��=1−��αt=1−βt, 以及 ��‾=∏�=1���αt=∏i=1Tαt, 从而可以得到:
其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有 �(0,�12�)+�(0,�22�)∼�(0,(�12+�22)�)N(0,σ12I)+N(0,σ22I)∼N(0,(σ12+σ22)I), 因此:
��(1−��−1)�2∼�(0,��(1−��−1)�)1−���1∼�(0,(1−��)�)��(1−��−1)�2+1−���1∼�(0,[��(1−��−1)+(1−��)]�)=�(0,(1−����−1)�).at(1−αt−1)ϵ2∼N(0,at(1−αt−1)I)1−αtϵ1∼N(0,(1−αt)I)at(1−αt−1)ϵ2+1−αtϵ1∼N(0,[αt(1−αt−1)+(1−αt)]I)=N(0,(1−αtαt−1)I).
注意公式 (3-2) 中 �ˉ2∼�(0,�)ϵˉ2∼N(0,I), 因此还需乘上 1−����−11−αtαt−1. 从公式 (3) 可以看出
�(��∣�0)=�(��;�ˉ��0,(1−�ˉ�)�)q(xt∣x0)=N(xt;aˉtx0,(1−aˉt)I)
注意由于 ��∈(0,1)βt∈(0,1) 且 �1<…<��β1<…<βT, 而 ��=1−��αt=1−βt, 因此 ��∈(0,1)αt∈(0,1) 并且有 �1>…>��α1>…>αT, 另外由于 �ˉ�=∏�=1���αˉt=∏i=1Tαt, 因此当 �→∞T→∞ 时, �ˉ�→0αˉt→0 以及 (1−�ˉ�)→1(1−aˉt)→1, 此时 ��∼�(0,�)xT∼N(0,I). 从这里的推导来看, 在公式 (2) 中的均值 ��−1xt−1 前乘上系数 1−����−11−βtxt−1 会使得 ��xT 最后收敛到标准高斯分布.
前向阶段是加噪声的过程, 而逆向阶段则是将噪声去除, 如果能得到逆向过程的分布 �(��−1∣��)q(xt−1∣xt), 那么通过输入高斯噪声 ��∼�(0,�)xT∼N(0,I), 我们将生成一个真实的样本. 注意到当 ��βt 足够小时, �(��−1∣��)q(xt−1∣xt) 也是高斯分布, 具体的证明在 ewrfcas 的知乎文章: 由浅入深了解Diffusion Model 推荐的论文中: On the theory of stochastic processes, with particular reference to applications
. 我大致看了一下, 哈哈, 没太看明白, 不过想到这个不是我关注的重点, 因此 pass. 由于我们无法直接推断 �(��−1∣��)q(xt−1∣xt), 因此我们将使用深度学习模型 ��pθ 去拟合分布 �(��−1∣��)q(xt−1∣xt), 模型参数为 �θ:
��(�0:�)=�(��)∏�=1���(��−1∣��)��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�))pθ(x0:T)pθ(xt−1∣xt)=p(xT)t=1∏Tpθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))
注意到, 虽然我们无法直接求得 �(��−1∣��)q(xt−1∣xt) (注意这里是 �q 而不是模型 ��pθ), 但在知道 �0x0 的情况下, 可以通过贝叶斯公式得到 �(��−1∣��,�0)q(xt−1∣xt,x0) 为:
�(��−1∣��,�0)=�(��−1;�~(��,�0),�~��)q(xt−1∣xt,x0)=N(xt−1;μ~(xt,x0),β~tI)
推导过程如下:
�(��−1∣��,�0)=�(��∣��−1,�0)�(��−1∣�0)�(��∣�0)∝exp(−12((��−����−1)2��+(��−1−�ˉ�−1�0)21−�ˉ�−1−(��−�ˉ��0)21−�ˉ�))=exp(−12(��2−2������−1+����−12��+��−12−2�ˉ�−1�0��−1+�ˉ�−1�021−�ˉ�−1−(��−�ˉ��0)21−�ˉ�))=exp(−12((����+11−�ˉ�−1)��−12⏟��−1 方差 −(2������+2�ˉ�−11−�ˉ�−1�0)��−1⏟��−1 均值 +�(��,�0)⏟与 ��−1 无关 ))q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)∝exp(−21(βt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2))=exp(−21(βtxt2−2αtxtxt−1+αtxt−12+1−αˉt−1xt−12−2αˉt−1x0xt−1+αˉt−1x02−1−αˉt(xt−αˉtx0)2))=exp(−21(xt−1 方差 (βtαt+1−αˉt−11)xt−12−xt−1 均值 (βt2αtxt+1−αˉt−12αˉt−1x0)xt−1+与 xt−1 无关 C(xt,x0)))
上面推导过程中, 通过贝叶斯公式巧妙的将逆向过程转换为前向过程, 且最终得到的概率密度函数和高斯概率密度函数的指数部分 exp(−(�−�)22�2)=exp(−12(1�2�2−2��2�+�2�2))exp(−2σ2(x−μ)2)=exp(−21(σ21x2−σ22μx+σ2μ2)) 能对应, 即有:
通过公式 (8) 和公式 (9), 我们能得到 �(��−1∣��,�0)q(xt−1∣xt,x0) 的分布. 此外由于公式 (3) 揭示的 ��xt 和 �0x0 之间的关系: ��=�ˉ��0+1−�ˉ��ˉ�xt=αˉtx0+1−αˉtϵˉt, 可以得到
�0=1�ˉ�(��−1−�ˉ���)x0=αˉt1(xt−1−αˉtϵt)
代入公式 (9) 中得到:
补充一下公式 (11) 的详细推导过程:
前面说到, 我们将使用深度学习模型 ��pθ 去拟合逆向过程的分布 �(��−1∣��)q(xt−1∣xt), 由上面公式知 ��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�))pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t)), 我们希望训练模型 ��(��,�)μθ(xt,t) 以预估 �~�=1��(��−1−��1−�ˉ���)μ~t=αt1(xt−1−αˉt1−αtϵt). 由于 ��xt 在训练阶段会作为输入, 因此它是已知的, 我们可以转而让模型去预估噪声 ��ϵt, 即令:
��(��,�)=1��(��−1−��1−�ˉ���(��,�))Thus ��−1=�(��−1;1��(��−1−��1−�ˉ���(��,�)),��(��,�))μθ(xt,t)Thus xt−1=αt1(xt−1−αˉt1−αtϵθ(xt,t))=N(xt−1;αt1(xt−1−αˉt1−αtϵθ(xt,t)),Σθ(xt,t))
前面谈到, 逆向阶段让模型去预估噪声 ��(��,�)ϵθ(xt,t), 那么应该如何设计 Loss 函数 ? 我们的目标是在真实数据分布下, 最大化模型预测分布的对数似然, 即优化在 �0∼�(�0)x0∼q(x0) 下的 ��(�0)pθ(x0) 交叉熵:
�=��(�0)[−log��(�0)]L=Eq(x0)[−logpθ(x0)]
和 变分自动编码器 VAE 类似, 使用 Variational Lower Bound 来优化: −log��(�0)−logpθ(x0) :
对公式 (15) 左右两边取期望 ��(�0)Eq(x0), 利用到重积分中的 Fubini 定理 可得:
����=��(�0)(��(�1:�∣�0)[log�(�1:�∣�0)��(�0:�)])=��(�0:�)[log�(�1:�∣�0)��(�0:�)]⏟Fubini定理 ≥��(�0)[−log��(�0)]LVLB=Fubini定理 Eq(x0)(Eq(x1:T∣x0)[logpθ(x0:T)q(x1:T∣x0)])=Eq(x0:T)[logpθ(x0:T)q(x1:T∣x0)]≥Eq(x0)[−logpθ(x0)]
因此最小化 ����LVLB 就可以优化目标函数 �L. 之后对 ����LVLB 做进一步的推导, 这部分的详细推导见上面的参考文章, 最终的结论是:
����=��+��−1+…+�0��=���(�(��∣�0)∣∣��(��))��=���(�(��∣��−1,�0)∣∣��(��∣��+1));1≤�≤�−1�0=−log��(�0∣�1)LVLBLTLtL0=LT+LT−1+…+L0=DKL(q(xT∣x0)∣∣pθ(xT))=DKL(q(xt∣xt−1,x0)∣∣pθ(xt∣xt+1));1≤t≤T−1=−logpθ(x0∣x1)
最终是优化两个高斯分布 �(��∣��−1,�0)=�(��−1;�~(��,�0),�~��)q(xt∣xt−1,x0)=N(xt−1;μ~(xt,x0),β~tI) 与 ��(��∣��+1)=�(��−1;��(��,�),Σ�)pθ(xt∣xt+1)=N(xt−1;μθ(xt,t),Σθ) (此为模型预估的分布)之间的 KL 散度. 由于多元高斯分布的 KL 散度存在闭式解, 详见: Multivariate_normal_distributions, 从而可以得到:
��=��0,�[12∥��(��,�)∥22∥�~�(��,�0)−��(��,�)∥2]=��0,�[12∥��∥22∥1��(��−1−��1−�ˉ���)−1��(��−1−��1−�ˉ���(��,�))∥2]=��0,�[(1−��)22��(1−�ˉ�)∥��∥22∥��−��(��,�)∥2];其中��为高斯噪声,��为模型学习的噪声=��0,�[(1−��)22��(1−�ˉ�)∥��∥22∥��−��(�ˉ��0+1−�ˉ���,�)∥2]Lt=Ex0,ϵ[2∥Σθ(xt,t)∥221∥μ~t(xt,x0)−μθ(xt,t)∥2]=Ex0,ϵ[2∥Σθ∥221∥αt1(xt−1−αˉt1−αtϵt)−αt1(xt−1−αˉt1−αtϵθ(xt,t))∥2]=Ex0,ϵ[2αt(1−αˉt)∥Σθ∥22(1−αt)2∥ϵt−ϵθ(xt,t)∥2];其中ϵt为高斯噪声,ϵθ为模型学习的噪声=Ex0,ϵ[2αt(1−αˉt)∥Σθ∥22(1−αt)2∥ϵt−ϵθ(αˉtx0+1−αˉtϵt,t)∥2]
DDPM 将 Loss 简化为如下形式:
��simple =��0,��[∥��−��(�ˉ��0+1−�ˉ���,�)∥2]Ltsimple =Ex0,ϵt[∥∥ϵt−ϵθ(αˉtx0+1−αˉtϵt,t)∥∥2]
因此 Diffusion 模型的目标函数即是学习高斯噪声 ��ϵt 和 ��ϵθ (来自模型输出) 之间的 MSE loss.
最终 DDPM 的算法流程如下:
训练阶段重复如下步骤:
逆向阶段采用如下步骤进行采样:
DDPM 文章以及代码的相关信息如下:
本文以分析 Tensorflow 源码为主, Pytorch 版本的代码和 Tensorflow 版本的实现逻辑大体不差的, 变量名字啥的都类似, 阅读起来不会有啥门槛. Tensorlow 源码对 Diffusion 模型的实现位于 diffusion_utils_2.py, 模型本身的分析以该文件为主.
以 CIFAR 数据集为例.
在 run_cifar.py 中进行前向传播计算 Loss:
training_losses
定义在 GaussianDiffusion2 中, 计算噪声间的 MSE Loss.进入 GaussianDiffusion2 中, 看到初始化函数中定义了诸多变量, 我在注释中使用公式的方式进行了说明:
下面进入到 training_losses
函数中:
self.model_mean_type
默认是 eps
, 模型学习的是噪声, 因此 target
是第 6 行定义的 noise
, 即 ��ϵtself.q_sample
计算 ��xt, 即公式 (3) ��=�ˉ��0+1−�ˉ���xt=αˉtx0+1−αˉtϵtdenoise_fn
是定义在 unet.py 中的 UNet
模型, 只需知道它的输入和输出大小相同; 结合第 9 行得到的 ��xt, 得到模型预估的噪声: ��(�ˉ��0+1−�ˉ���,�)ϵθ(αˉtx0+1−αˉtϵt,t)上面第 9 行定义的 self.q_sample
详情如下:
q_sample
已经介绍过, 不多说._extract
在代码中经常被使用到, 看到它只需知道它是用来提取系数的即可. 引入输入是一个 Batch, 里面的每个样本都会随机采样一个 time step �t, 因此需要使用 tf.gather
来将 ��ˉαtˉ 之类选出来, 然后将系数 reshape 为 [B, 1, 1, ....]
的形式, 目的是为了利用 broadcasting 机制和 ��xt 这个 Tensor 相乘.前向的训练阶段代码实现非常简单, 下面看逆向阶段
逆向阶段代码定义在 GaussianDiffusion2 中:
self.p_sample
就是公式 (6) ��(��−1∣��)=�(��−1;��(��,�),Σ�(��,�))pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t)) 的过程, 使用模型来预估 ��(��,�)μθ(xt,t) 以及 Σ�(��,�)Σθ(xt,t)denoise_fn
在前面说过, 是定义在 unet.py 中的 UNet
模型; img_
表示 ��xt.noise_fn
则默认是 tf.random_normal
, 用于生成高斯噪声.进入 p_sample
函数:
self.p_mean_variance
生成 ��(��,�)μθ(xt,t) 以及 log(Σ�(��,�))log(Σθ(xt,t)), 其中 Σ�(��,�)Σθ(xt,t) 通过计算 �~�β~t 得到.进入 self.p_mean_variance
函数:
denoise_fn
, 通过输入 ��xt, 输出得到噪声 ��ϵtself.model_var_type
默认为 fixedlarge
, 但我当时看 fixedsmall
比较爽, 因此 model_variance
和 model_log_variance
分别为 �~�=1−�ˉ�−11−�ˉ�⋅��β~t=1−αˉt1−αˉt−1⋅βt (见公式 8), 以及 log�~�logβ~tself._predict_xstart_from_eps
函数, 利用公式 (10) 得到 �0=1�ˉ�(��−1−�ˉ���)x0=αˉt1(xt−1−αˉtϵt)self.q_posterior_mean_variance
通过公式 (9) 得到 ��(��,�0)=��(1−�ˉ�−1)1−�ˉ���+�ˉ�−1��1−�ˉ��0μθ(xt,x0)=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0self._predict_xstart_from_eps
函数详情如下:
self.q_posterior_mean_variance
函数详情如下:
本文分析了扩散模型 DDPM 算法,对原理以及代码进行了剖析,公式比较多,手推一遍再结合代码分析会有更深的体会。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。