当前位置:   article > 正文

Diffusion学习笔记_diffusion 模型笔记

diffusion 模型笔记


一些资料

主要部分是对该文章的翻译:
What are Diffusion Models?

提出扩散模型的文章:链接
DDPM:链接

一、Diffusion 模型概述

扩散模型受到非平衡热力学的启发。他们定义了扩散步骤的马尔可夫链,以缓慢地将随机噪声添加到数据中,然后学习反转扩散过程,从噪声中构建所需的数据样本。与VAE或流动模型不同,扩散模型是通过固定的过程学习的,并且潜在变量具有高维度(与原始数据相同)。
Diffusion Model的训练过程通常分为两个阶段:前向扩散过程和后向逆扩散过程。

1、前向扩散过程

这个阶段模拟了从真实数据到随机噪声的过程。给定一个原始数据样本,通过连续的T个时间步长,在每个时间步长上都添加一些随机噪声。这样,随着时间的推移,原始数据逐渐被随机噪声所覆盖,最终形成一个完全由噪声构成的样本。
x 0 x_0 x0为从实际数据分布 x 0 ∼ q ( x ) x_0\sim q\left( x\right) x0q(x)采样得到的一个样本点。在前向扩散的 T T T步过程中,我们通过逐渐向 x 0 x_0 x0中添加少量的高斯噪声,生成了一系列添加了噪声的样本: x 1 , ⋯   , x T x_1,\cdots ,x_T x1,,xT。每一步的步长由 { β t ∈ ( 0 , 1 ) } t = 1 T \left\lbrace \beta _t \in \left( 0,1\right) \right\rbrace ^T_{t=1} {βt(0,1)}t=1T决定。
每一步的转移概率服从高斯分布: q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q\left( x_t | x_{t-1}\right)=\mathcal{N}\left( x_t;\sqrt{1-\beta_t}x_{t-1},\beta_tI\right) q(xtxt1)=N(xt;1βt xt1,βtI)经过T步有: q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q\left( x{1:T} | x_0\right)=\prod_{t=1}^Tq\left( x_t|x_{t-1}\right) q(x1:Tx0)=t=1Tq(xtxt1)在这个过程中,随着t的增长,样本 x 0 x_0 x0逐渐失去了精确性。当 T → ∞ T\to \infin T时, x t x_t xt服从各向同性高斯分布(an Isotropic Gaussian distribution)。
各向同性高斯分布可参考各向同性高斯分布
上述过程的一个很好的属性就是,我们可以使用重参数技巧在任意时间步 t t t以一种“closed form”对 x t x_t xt进行采样。重参数技巧使得高斯分布的随机性转移到一个参数上

重参数技巧(reparameterization trick):
z ∼ N ( z ; μ , σ I ) z\sim\mathcal{N}\left(z;\mu,\sigma I\right) zN(z;μ,σI),则 z = μ + σ ⊙ ϵ z=\mu+\sigma\odot\epsilon z=μ+σϵ,其中 ϵ ∼ N ( 0 , I ) , ⊙ \epsilon\sim\mathcal{N}\left( 0,I\right),\odot ϵN(0,I)表示element-wise 乘积。

α t = 1 − β t \alpha_t=1-\beta_t αt=1βt并且 α t ‾ = ∏ i = 1 t α i \overline{\alpha_t}=\prod_{i=1}^t\alpha_i αt=i=1tαi,则
x t = α t x t − 1 + 1 − α t ϵ t − 1 = α t ( α t − 1 x t − 2 + 1 − α t − 1 ϵ t − 2 ) + 1 − α t ϵ t − 1 = α t α t − 1 x t − 2 + α t 1 − α t − 1 ϵ t − 2 + 1 − α t ϵ t − 1 = α t α t − 1 x t − 2 + α t ( 1 − α t − 1 ) + 1 − α t ϵ ‾ t − 2 = α t α t − 1 x t − 2 + 1 − α t α t − 1 ϵ ˉ t − 2 = ⋯ = α ˉ t x 0 + 1 − α ˉ t ϵ q ( x t ∣ x 0 ) = N ( x t ; α ‾ x 0 , ( 1 − α ‾ t ) I ) (*) \begin{aligned} x_t=& \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}\epsilon_{t-1}\\ =& \sqrt{\alpha_t}\left( \sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2}\right)+\sqrt{1-\alpha_t}\epsilon_{t-1}\\ =&\sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + \sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}\epsilon_{t-2} +\sqrt{1-\alpha_t}\epsilon_{t-1} \\ =& \sqrt{\alpha_t\alpha_{t-1}}x_{t-2}+\sqrt{\alpha_t\left(1-\alpha_{t-1}\right)+1-\alpha_t}\overline{\epsilon}_{t-2}\\ =&\sqrt{\alpha_t\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\bar{\epsilon}_{t-2} \\ =&\cdots \\ =&\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon \tag*{(*)}\\ q(x_t|x_0)=&\mathcal{N}\left( x_t;\sqrt{\overline{\alpha}}x_0, \left( 1-\overline{\alpha}_t\right)I\right) \end{aligned}

\begin{aligned} x_t=& \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}\epsilon_{t-1}\\ =& \sqrt{\alpha_t}\left( \sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2}\right)+\sqrt{1-\alpha_t}\epsilon_{t-1}\\ =&\sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + \sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}\epsilon_{t-2} +\sqrt{1-\alpha_t}\epsilon_{t-1} \\ =& \sqrt{\alpha_t\alpha_{t-1}}x_{t-2}+\sqrt{\alpha_t\left(1-\alpha_{t-1}\right)+1-\alpha_t}\overline{\epsilon}_{t-2}\\ =&\sqrt{\alpha_t\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\bar{\epsilon}_{t-2} \\ =&\cdots \\ =&\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon \tag*{(*)}\\ q(x_t|x_0)=&\mathcal{N}\left( x_t;\sqrt{\overline{\alpha}}x_0, \left( 1-\overline{\alpha}_t\right)I\right) \end{aligned}
xt=======q(xtx0)=αt xt1+1αt ϵt1αt (αt1 xt2+1αt1 ϵt2)+1αt ϵt1αtαt1 xt2+αt 1αt1 ϵt2+1αt ϵt1αtαt1 xt2+αt(1αt1)+1αt ϵt2αtαt1 xt2+1αtαt1 ϵˉt2αˉt x0+1αˉt ϵN(xt;α x0,(1αt)I)*其中, α ˉ t = ∏ t = 1 T α t , 且 ϵ t − 1 , ϵ t − 2 , ⋯ ∼ N ( 0 , I ) \bar{\alpha}_t=\prod_{t=1}^T\alpha_t,且\epsilon_{t-1},\epsilon_{t-2},\cdots \sim \mathcal{N}\left( 0,I\right) αˉt=t=1Tαt,ϵt1,ϵt2,N(0,I)

解释上式第三行到第四行的转换:
两个独立高斯分布的和,仍然服从高斯分布:
x 1 x_1 x1 x 2 x_2 x2为相互独立的两个随机变量。 x 1 ∼ N ( 0 , σ 1 2 I ) , x 2 ∼ N ( 0 , σ 2 2 I ) x_1\sim \mathcal{N}\left( 0,\sigma_1^2I\right),x_2\sim\mathcal{N}\left( 0,\sigma_2^2I\right) x1N(0,σ12I),x2N(0,σ22I),则两者之和 x 3 = x 1 + x 2 x_3=x_1+x_2 x3=x1+x2满足 x 3 ∼ N ( 0 , ( σ 1 2 + σ 2 2 ) I ) x_3\sim\mathcal{N}\left( 0,\left(\sigma_1^2 + \sigma_2^2\right)I\right) x3N(0,(σ12+σ22)I)
因此,由 ϵ t − 1 , ϵ t − 2 , ⋯ ∼ N ( 0 , I ) \epsilon_{t-1},\epsilon_{t-2},\cdots \sim \mathcal{N}\left( 0,I\right) ϵt1,ϵt2,N(0,I),有
α t 1 − α t − 1 ϵ t − 2 ∼ N ( 0 , α t ( 1 − α t − 1 ) I ) , 1 − α t ϵ t − 1 ∼ N ( 0 , ( 1 − α t ) I ) \sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}\epsilon_{t-2}\sim\mathcal{N}\left( 0,\alpha_t\left(1-\alpha_{t-1}\right)I\right),\sqrt{1-\alpha_t}\epsilon_{t-1}\sim\mathcal{N}\left( 0,\left(1-\alpha_t\right)I\right) αt 1αt1 ϵt2N(0,αt(1αt1)I),1αt ϵt1N(0,(1αt)I)因此有 ϵ ‾ t − 2 ∼ N ( 0 , ( 1 − α t α t − 1 ) I ) \overline{\epsilon}_{t-2}\sim\mathcal{N}\left( 0,\left(1-\alpha_t\alpha_{t-1}\right)I\right) ϵt2N(0,(1αtαt1)I)

2、后向逆扩散过程

在后向过程中,模型学习如何通过去噪函数逐步地从随机噪声中恢复原始数据。这个过程可以看作是前向过程的反向操作,即逐步减少噪声并增加对原始数据的恢复。
如果我们能够反向进行上面的操作并且从 q ( x t − 1 ∣ x t ) q\left(x_{t-1}|x_t\right) q(xt1xt)中采样,我们就能够从高斯噪声 x t ∼ N ( 0 , I ) x_t\sim\mathcal{N}\left(0,I\right) xtN(0,I)中还原原始数据。注意到 β t \beta_t βt足够小,这个逆过程仍为高斯分布。我们无法容易地估计 q ( x t − 1 ∣ x t ) q\left(x_{t-1}|x_t\right) q(xt1xt),因为这需要得知整个数据集的分布,因此,我们采取学习模型 p θ p_\theta pθ的方式来估计这个条件概率。
p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta\left(x_{0:T}\right)=p\left(x_T\right)\prod_{t=1}^Tp_\theta\left(x_{t-1}|x_t\right)\\ p_\theta\left(x_{t-1}|x_t\right)=\mathcal{N}\left(x_{t-1};\mu_\theta\left(x_t,t\right),\Sigma_\theta\left(x_t,t\right)\right) pθ(x0:T)=p(xT)t=1Tpθ(xt1xt)pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))
值得注意:
当条件中包含 x 0 x_0 x0时,后向逆扩散过程的条件概率是可以求得的。
由贝叶斯公式,有:
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) ∝ exp ⁡ ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = exp ⁡ ( − 1 2 ( x t 2 − 2 α t x t x t − 1 + α t x t − 1 2 β t + x t − 1 2 − 2 α ˉ t − 1 x 0 x t − 1 + α ˉ t − 1 x 0 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = exp ⁡ ( − 1 2 ( ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 − ( 2 α t β t x t + 2 α ˉ t − 1 1 − α ˉ t − 1 x 0 ) x t − 1 + C ( x t , x 0 ) ) ) q(xt1|xt,x0)=q(xt|xt1,x0)q(xt1|x0)q(xt|x0)exp(12((xtαtxt1)2βt+(xt1ˉαt1x0)21ˉαt1(xtˉαtx0)21ˉαt))=exp(12(x2t2αtxtxt1+αtx2t1βt+x2t12ˉαt1x0xt1+ˉαt1x201ˉαt1(xtˉαtx0)21ˉαt))=exp(12((αtβt+11ˉαt1)x2t1(2αtβtxt+2ˉαt11ˉαt1x0)xt1+C(xt,x0)))

q(xt1xt,x0)===q(xtxt1,x0)q(xtx0)q(xt1x0)exp(21(βt(xtαt xt1)2+1αˉt1(xt1αˉt1 x0)21αˉt(xtαˉt x0)2))exp(21(βtxt22αt xtxt1+αtxt12+1αˉt1xt122αˉt1 x0xt1+αˉt1x021αˉt(xtαˉt x0)2))exp(21((βtαt+1αˉt11)xt12(βt2αt xt+1αˉt12αˉt1 x0)xt1+C(xt,x0)))
其中, C ( x t , x 0 ) C\left( x_t,x_0 \right) C(xt,x0) x t − 1 x_{t-1} xt1无关,因此略去细节。遵循高斯密度函数的形式,可以经过整理得出均值和方差:
β ~ t = 1 ( α t β t + 1 1 − α ˉ t − 1 ) = 1 ( α t − α ˉ t + β t β t ( 1 − α ˉ t − 1 ) ) = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t μ ~ t ( x t , x 0 ) = ( α t β t x t + α ˉ t − 1 1 − α ˉ t − 1 x 0 ) ( α t β t + 1 1 − α ˉ t − 1 ) = ( α t β t x t + α ˉ t − 1 1 − α ˉ t − 1 x 0 ) ⋅ 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 ˜βt=1(αtβt+11ˉαt1)=1(αtˉαt+βtβt(1ˉαt1))=1ˉαt11ˉαtβt˜μt(xt,x0)=(αtβtxt+ˉαt11ˉαt1x0)(αtβt+11ˉαt1)=(αtβtxt+ˉαt11ˉαt1x0)1ˉαt11ˉαtβt=αt(1ˉαt1)1ˉαtxt+ˉαt1βt1ˉαtx0
β~t=μ~t(xt,x0)===(βtαt+1αˉt11)1=(βt(1αˉt1)αtαˉt+βt)1=1αˉt1αˉt1βt(βtαt+1αˉt11)(βtαt xt+1αˉt1αˉt1 x0)(βtαt xt+1αˉt1αˉt1 x0)1αˉt1αˉt1βt1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0

(*)式即 x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon xt=αˉt x0+1αˉt ϵ代入上式消去 x 0 x_0 x0,有:
μ ~ t = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) \tilde{\mu}_t=\frac{1}{\sqrt{\alpha_t}}\left( x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t\right) μ~t=αt 1(xt1αˉt 1αtϵt)

二、训练过程

在训练时,需要使用神经网络来拟合从当前噪声状态一步步回到原始数据的条件概率分布。首先自行复习一下熵,交叉熵和KL散度

1、变分下界

这个过程与VAE十分类似,因此我们可以使用变分下界(Variational lower bound,VLB)来优化负对数似然(negative log-likelihood)。
− log ⁡ p θ ( x 0 ) ≤ − log ⁡ p θ ( x 0 ) + D K L ( q ( x 1 : T ∣ x 0 ) ∥ p θ ( x 1 : T ∣ x 0 ) ) = − log ⁡ p θ ( x 0 ) + E x 1 : T ∼ q ( x 1 : T ∣ x 0 ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) / p θ ( x 0 ) ] = − log ⁡ p θ ( x 0 ) + E q [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) + log ⁡ p θ ( x 0 ) ] = E q [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] 令 L V L B = E q ( x 0 : T ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] logpθ(x0)logpθ(x0)+DKL(q(x1:T|x0)pθ(x1:T|x0))=logpθ(x0)+Ex1:Tq(x1:T|x0)[logq(x1:T|x0)pθ(x0:T)/pθ(x0)]=logpθ(x0)+Eq[logq(x1:T|x0)pθ(x0:T)+logpθ(x0)]=Eq[logq(x1:T|x0)pθ(x0:T)]LVLB=Eq(x0:T)[logq(x1:T|x0)pθ(x0:T)]

logpθ(x0)===LVLB=logpθ(x0)+DKL(q(x1:Tx0)pθ(x1:Tx0))logpθ(x0)+Ex1:Tq(x1:Tx0)[logpθ(x0:T)/pθ(x0)q(x1:Tx0)]logpθ(x0)+Eq[logpθ(x0:T)q(x1:Tx0)+logpθ(x0)]Eq[logpθ(x0:T)q(x1:Tx0)]Eq(x0:T)[logpθ(x0:T)q(x1:Tx0)]
也可以直接使用Jensen不等式:

Jenson不等式:若 f ( x ) f\left(x\right) f(x)是区间 [ a , b ] \left[a,b\right] [a,b]上的凸函数,则对任意的 x 1 , x 2 , ⋯   , x n ∈ [ a , b ] x_1,x_2,\cdots,x_n\in \left[a,b\right] x1,x2,,xn[a,b],则有:
f ( ∑ i = 1 n x i n ) ≥ ∑ i = 1 n f ( x i ) n f\left(\sum_{i=1}^{n}\frac{x_i}{n}\right)\ge\frac{\sum_{i=1} ^{n}f\left(x_i\right)}{n} f(i=1nnxi)ni=1nf(xi)
f ( E ( x ) ) ≥ E ( f ( x ) ) f\left(\mathbb{E}\left(x\right)\right)\ge\mathbb{E}\left(f(x)\right) f(E(x))E(f(x))

因此,交叉熵损失(Cross Entropy)可以像下面这样:
L C E = − E q ( x 0 ) log ⁡ p θ ( x 0 ) = − E q ( x 0 ) log ⁡ ( ∫ p θ ( x 0 : T ) d x 1 : T ) = − E q ( x 0 ) log ⁡ ( ∫ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) q ( x 1 : T ∣ x 0 ) d x 1 : T ) = − E q ( x 0 ) log ⁡ ( E q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) q ( x 1 : T ∣ x 0 ) ) ≤ − E q ( x 0 : T ) log ⁡ ( p θ ( x 0 : T ) q ( x 1 : T ∣ x 0 ) ) = E q ( x 0 : T ) log ⁡ ( q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ) = L V L B LCE=Eq(x0)logpθ(x0)=Eq(x0)log(pθ(x0:T)dx1:T)=Eq(x0)log(q(x1:T|x0)pθ(x0:T)q(x1:T|x0)dx1:T)=Eq(x0)log(Eq(x1:T|x0)pθ(x0:T)q(x1:T|x0))Eq(x0:T)log(pθ(x0:T)q(x1:T|x0))=Eq(x0:T)log(q(x1:T|x0)pθ(x0:T))=LVLB

LCE======Eq(x0)logpθ(x0)Eq(x0)log(pθ(x0:T)dx1:T)Eq(x0)log(q(x1:Tx0)q(x1:Tx0)pθ(x0:T)dx1:T)Eq(x0)log(Eq(x1:Tx0)q(x1:Tx0)pθ(x0:T))Eq(x0:T)log(q(x1:Tx0)pθ(x0:T))Eq(x0:T)log(pθ(x0:T)q(x1:Tx0))LVLB倒数第三行用到了Jensen不等式,对凸函数而言,期望的函数大于函数的期望,由于前面有负号,因此这里是小于号。于是,我们只需要优化 L V L B L_{VLB} LVLB(很有意思,变分界损失( L V L B L_{VLB} LVLB)实际上是交叉熵损失( L C E L_{CE} LCE)的界)就可以间接地压缩 L C E L_{CE} LCE了。
下面,为了能够更好地计算 L V L B L_{VLB} LVLB,我们将其转化为多个KL散度和熵的和。
L V L B = E q ( x 0 : T ) log ⁡ ( q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ) = E q log ⁡ ( ∏ t = 1 T q ( x t ∣ x t − 1 ) p θ ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) ) = E q [ − log ⁡ p θ ( x T ) + ∑ t = 1 T log ⁡ q ( x t ∣ x t − 1 ) p θ ( x t − 1 ∣ x t ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t ∣ x t − 1 ) p θ ( x t − 1 ∣ x t ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) + ∑ t = 2 T log ⁡ q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) + log ⁡ q ( x T ∣ x 0 ) q ( x 1 ∣ x 0 ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ log ⁡ q ( x T ∣ x 0 ) p θ ( x 0 ∣ x 1 ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) − log ⁡ p θ ( x T ) ] = E q [ D K L ( q ( x T ∣ x 0 ) ∣ ∣ p θ ( x 0 ∣ x 1 ) ) ⏟ L t + ∑ t = 2 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) ⏟ L t − 1 − log ⁡ p θ ( x T ) ⏟ L 0 ] LVLB=Eq(x0:T)log(q(x1:T|x0)pθ(x0:T))=Eqlog(Tt=1q(xt|xt1)pθ(xT)Tt=1pθ(xt1|xt))=Eq[logpθ(xT)+Tt=1logq(xt|xt1)pθ(xt1|xt)]=Eq[logpθ(xT)+Tt=2logq(xt|xt1)pθ(xt1|xt)+logq(x1|x0)pθ(x0|x1)]=Eq[logpθ(xT)+Tt=2logq(xt1|xt,x0)pθ(xt1|xt)q(xt|x0)q(xt1|x0)+logq(x1|x0)pθ(x0|x1)]=Eq[logpθ(xT)+Tt=2logq(xt1|xt,x0)pθ(xt1|xt)+Tt=2logq(xt|x0)q(xt1|x0)+logq(x1|x0)pθ(x0|x1)]=Eq[logpθ(xT)+Tt=2logq(xt1|xt,x0)pθ(xt1|xt)+logq(xT|x0)q(x1|x0)+logq(x1|x0)pθ(x0|x1)]=Eq[logq(xT|x0)pθ(x0|x1)+Tt=2logq(xt1|xt,x0)pθ(xt1|xt)logpθ(xT)]=Eq[DKL(q(xT|x0)||pθ(x0|x1))Lt+Tt=2DKL(q(xt1|xt,x0)||pθ(xt1|xt))Lt1logpθ(xT)L0]
LVLB=========Eq(x0:T)log(pθ(x0:T)q(x1:Tx0))Eqlog(pθ(xT)t=1Tpθ(xt1xt)t=1Tq(xtxt1))Eq[logpθ(xT)+t=1Tlogpθ(xt1xt)q(xtxt1)]Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xtxt1)+logpθ(x0x1)q(x1x0)]Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)q(xt1x0)q(xtx0)+logpθ(x0x1)q(x1x0)]Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)+t=2Tlogq(xt1x0)q(xtx0)+logpθ(x0x1)q(x1x0)]Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)+logq(x1x0)q(xTx0)+logpθ(x0x1)q(x1x0)]Eq[logpθ(x0x1)q(xTx0)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)logpθ(xT)]Eq Lt DKL(q(xTx0)∣∣pθ(x0x1))+t=2TLt1 DKL(q(xt1xt,x0)∣∣pθ(xt1xt))L0 logpθ(xT)

使用上面几个符号来表达 L V L B L_{VLB} LVLB更加简易:
L V L B = L 0 + L 1 + ⋯ + L t + ⋯ + L T − 1 + L T 其中, L T = D K L ( q ( x T ∣ x 0 ) ∣ ∣ p θ ( x 0 ∣ x 1 ) ) L 0 = log ⁡ p θ ( x T ) L t = D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) , t ∈ [ 1 , T − 1 ] LVLB=L0+L1++Lt++LT1+LTLT=DKL(q(xT|x0)||pθ(x0|x1))L0=logpθ(xT)Lt=DKL(q(xt1|xt,x0)||pθ(xt1|xt)),t[1,T1]
LVLB=其中,LT=L0=Lt=L0+L1++Lt++LT1+LTDKL(q(xTx0)∣∣pθ(x0x1))logpθ(xT)DKL(q(xt1xt,x0)∣∣pθ(xt1xt)),t[1,T1]

上式中的每一个KL项都是两个高斯分布之间的比较,因此可以以“closed form”求解。 L T L_T LT是一个常量,没有可学习的部分,因此在训练中直接忽略。(Ho et al. 2020)使用了一个从 N ( x 0 ; μ θ ( x 1 , 1 ) , Σ θ ( x 1 , 1 ) ) \mathcal{N}\left(x_0;\mu_\theta\left(x_1,1\right),\Sigma_\theta\left(x_1,1\right)\right) N(x0;μθ(x1,1),Σθ(x1,1))中产生的单独的离散decoder来给 L 0 L_0 L0建模。

2、 L t L_t Lt的重参数化

为了在后向逆过程中使用神经网络拟合条件概率分布: p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t − 1 , t ) , Σ θ ( x t , t ) ) p_\theta\left(x_{t-1}|x_t\right)=\mathcal{N}\left(x_{t-1};\mu_\theta\left(x_{t-1},t\right),\Sigma_\theta\left(x_t,t\right)\right) pθ(xt1xt)=N(xt1;μθ(xt1,t),Σθ(xt,t)),我们需要训练一个 μ θ \mu_\theta μθ来拟合 μ ~ t = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) \tilde{\mu}_t=\frac{1}{\sqrt{\alpha_t}}\left( x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_t\right) μ~t=αt 1(xt1αˉt 1αtϵt) x t x_t xt作为输入是已知的,因此我们可以对高斯噪声进行重参数化,在时间步 t t t,从 x t x_t xt中预测 ϵ t \epsilon_t ϵt
μ θ ( x t , t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) 因此, p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) , Σ θ ( x t , t ) ) μθ(xt,t)=1αt(xt1αt1ˉαtϵθ(xt,t))pθ(xt1|xt)=N(xt1;1αt(xt1αt1ˉαtϵθ(xt,t)),Σθ(xt,t))

μθ(xt,t)=因此,pθ(xt1xt)=αt 1(xt1αˉt 1αtϵθ(xt,t))N(xt1;αt 1(xt1αˉt 1αtϵθ(xt,t)),Σθ(xt,t))
损失 L t L_t Lt被重参数化,用来缩小 μ θ \mu_\theta μθ μ ~ t \tilde{\mu}_t μ~t之间的差距:
L t = E x 0 , ϵ [ 1 2 ∥ Σ θ ( x t , t ) ∥ 2 2 ∥ μ θ ( x t , x 0 ) − μ ~ t ( x t , x 0 ) ∥ 2 ] = E x 0 , ϵ [ 1 2 ∥ Σ θ ( x t , t ) ∥ 2 2 ∥ 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) − 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) ∥ 2 ] = E x 0 , ϵ [ ( 1 − α t ) 2 2 α t ( 1 − α ˉ t ) ∥ Σ θ ∥ 2 ∥ ϵ θ ( x t , t ) − ϵ t ∥ 2 ] = E x 0 , ϵ [ ( 1 − α t ) 2 2 α t ( 1 − α ˉ t ) ∥ Σ θ ∥ 2 ∥ ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ t , t ) − ϵ t ∥ 2 ] Lt=Ex0,ϵ[12Σθ(xt,t)22μθ(xt,x0)˜μt(xt,x0)2]=Ex0,ϵ[12Σθ(xt,t)221αt(xt1αt1ˉαtϵθ(xt,t))1αt(xt1αt1ˉαtϵt)2]=Ex0,ϵ[(1αt)22αt(1ˉαt)Σθ2ϵθ(xt,t)ϵt2]=Ex0,ϵ[(1αt)22αt(1ˉαt)Σθ2ϵθ(ˉαtx0+1ˉαtϵt,t)ϵt2]
Lt====Ex0,ϵ[2Σθ(xt,t)221μθ(xt,x0)μ~t(xt,x0)2]Ex0,ϵ[2Σθ(xt,t)221αt 1(xt1αˉt 1αtϵθ(xt,t))αt 1(xt1αˉt 1αtϵt)2]Ex0,ϵ[2αt(1αˉt)Σθ2(1αt)2ϵθ(xt,t)ϵt2]Ex0,ϵ[2αt(1αˉt)Σθ2(1αt)2ϵθ(αˉt x0+1αˉt ϵt,t)ϵt2]

总结

推导过程不是很复杂,主要是找对什么部分是可学习的,训练就揪着这一部分可劲造。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/677792
推荐阅读
相关标签
  

闽ICP备14008679号