赞
踩
奠基性的工作:
条件概率的一般形式
P
(
B
,
C
∣
A
)
=
P
(
B
∣
A
)
P
(
C
∣
A
,
B
)
P(B,C|A)=P(B|A)P(C|A,B)
P(B,C∣A)=P(B∣A)P(C∣A,B)
基于马尔可夫假设的条件概率
假设马尔可夫链关系
A
→
B
→
C
A\to B\to C
A→B→C,有
P
(
A
,
B
,
C
)
=
P
(
C
∣
B
)
P
(
B
∣
A
)
P
(
A
)
P(A,B,C)=P(C|B)P(B|A)P(A)
P(A,B,C)=P(C∣B)P(B∣A)P(A)
高斯分布的KL散度
对于两个单一变量的高斯分布p和q而言,他们的KL散度满足
K
L
(
N
(
μ
1
,
σ
1
2
)
,
N
(
μ
2
,
σ
2
2
)
)
=
log
σ
2
σ
1
−
1
2
+
σ
1
2
+
(
μ
1
−
μ
2
)
2
2
σ
2
2
KL(\mathcal{N}(\mu_1,\sigma_1^2),\mathcal{N}(\mu_2,\sigma_2^2))=\log\frac{\sigma_2}{\sigma_1}-\frac{1}{2}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}
KL(N(μ1,σ12),N(μ2,σ22))=logσ1σ2−21+2σ22σ12+(μ1−μ2)2
推导详见CSDN博客
参数重整化
若希望从高斯分布 N ( μ , σ 2 ) \mathcal{N}(\mu,\sigma^2) N(μ,σ2)中采样,可以先从标准分布 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1)得到 z z z,得到 σ ⋅ z + μ \sigma\cdot z+\mu σ⋅z+μ。
这样就可以将 σ \sigma σ和 μ \mu μ也作为仿射网络的一部分,而不是不可导的环境参数。
这个技巧在VAE和Diffusion model中大量被使用。
x → z , q ϕ ( z ∣ x ) z → x , p θ ( x ∣ z ) x\to z,\quad q_{\phi}(z|x)\\ z\to x,\quad p_{\theta}(x|z) x→z,qϕ(z∣x)z→x,pθ(x∣z)
此时
x
x
x的边缘概率分布可以改写为关于z的期望式
p
(
x
)
=
∫
z
p
θ
(
x
∣
z
)
p
(
z
)
d
z
=
∫
z
q
ϕ
(
z
∣
x
)
p
θ
(
x
∣
z
)
p
(
z
)
q
ϕ
(
z
∣
x
)
d
z
=
E
z
∼
q
ϕ
(
z
∣
x
)
p
θ
(
x
∣
z
)
p
(
z
)
q
ϕ
(
z
∣
x
)
此时的Evidence存在一个lower bound(ELBO)
log
p
(
x
)
=
log
E
z
∼
q
ϕ
(
z
∣
x
)
p
θ
(
x
∣
z
)
p
(
z
)
q
ϕ
(
z
∣
x
)
≥
E
z
∼
q
ϕ
(
z
∣
x
)
log
[
p
θ
(
x
∣
z
)
p
(
z
)
q
ϕ
(
z
∣
x
)
]
\log p(x)=\log\mathbb{E}_{z\sim q_\phi(z|x)}\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} \ge\mathbb{E}_{z\sim q_\phi(z|x)}\log\left[\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)}\right]
logp(x)=logEz∼qϕ(z∣x)qϕ(z∣x)pθ(x∣z)p(z)≥Ez∼qϕ(z∣x)log[qϕ(z∣x)pθ(x∣z)p(z)]
在训练中,我们需要最大化对数似然,即Evidence,可以通过最小化lower bound实现,而这个lower bound可以分为两部分:
所以,单层VAE的损失函数是可优化的。
基于同样的原理,
p
(
x
)
=
∫
z
1
∫
z
2
p
θ
(
x
,
z
1
,
z
2
)
d
z
1
d
z
2
=
∫
z
1
∫
z
2
q
ϕ
(
z
1
,
z
2
∣
x
)
p
θ
(
x
,
z
1
,
z
2
)
q
ϕ
(
z
1
,
z
2
∣
x
)
d
z
1
d
z
2
=
E
z
1
,
z
2
∼
q
ϕ
(
z
1
,
z
2
∣
x
)
p
θ
(
x
,
z
1
,
z
2
)
q
ϕ
(
z
1
,
z
2
∣
x
)
得到
log
p
(
x
)
≥
E
z
1
,
z
2
∼
q
ϕ
(
z
1
,
z
2
∣
x
)
log
p
θ
(
x
,
z
1
,
z
2
)
q
ϕ
(
z
1
,
z
2
∣
x
)
\log p(x)\ge \mathbb{E}_{z1,z_2\sim q_\phi(z_1,z_2|x)}\log \frac{p_\theta(x,z_1,z_2)}{q_\phi(z_1,z_2|x)}
logp(x)≥Ez1,z2∼qϕ(z1,z2∣x)logqϕ(z1,z2∣x)pθ(x,z1,z2)
如果上述过程满足马尔科夫假设,即
p
θ
(
x
,
z
1
,
z
2
)
=
p
(
x
∣
z
1
)
p
(
z
1
∣
z
2
)
p
(
z
2
)
q
(
z
1
,
z
2
∣
x
)
=
q
(
z
1
∣
x
)
q
(
z
2
∣
z
1
)
p_\theta(x,z_1,z_2)=p(x|z_1)p(z_1|z_2)p(z_2)\\ q(z_1,z_2|x)=q(z_1|x)q(z_2|z_1)
pθ(x,z1,z2)=p(x∣z1)p(z1∣z2)p(z2)q(z1,z2∣x)=q(z1∣x)q(z2∣z1)
(6)式能够被进一步简化为
L
(
θ
,
ϕ
)
=
E
q
(
z
1
,
z
2
∣
x
)
[
log
p
(
x
∣
z
1
)
−
log
q
(
z
1
∣
x
)
+
log
p
(
z
1
∣
z
2
)
−
log
q
(
z
2
∣
z
1
)
+
log
p
(
z
2
)
]
\mathcal{L}(\theta,\phi)=\mathbb{E}_{q(z_1,z_2|x)} \left[ \log p(x|z_1)-\log q(z_1|x)+\log p(z_1|z_2) -\log q(z_2|z_1) +\log p(z_2) \right]
L(θ,ϕ)=Eq(z1,z2∣x)[logp(x∣z1)−logq(z1∣x)+logp(z1∣z2)−logq(z2∣z1)+logp(z2)]
从右往左,从目标分布到噪声分布称为扩散过程,而我们希望学习到从左往右的逆扩散过程。上图中的第一行从左往右是扩散过程,第二行从右往左是逆扩散过程,而第三行是前两者的差值,称为偏移量。
给定初始数据分布 x 0 ∼ q ( x ) \bold{x_0}\sim q(\bold{x}) x0∼q(x),不断向分布中添加高斯噪声,噪声的标准差是以 β t \beta_t βt确定的,均值是以固定值 β t \beta_t βt和当前时刻的数据 x t \bold{x_t} xt决定的,所以该过程并没有需要学习的参数,而且是一个马尔科夫链过程。
随着
t
t
t的不断增大,最终数据分布
x
T
x_T
xT变成了一个各项独立的高斯分布
q
(
x
t
∣
x
t
−
1
)
=
N
(
x
t
;
1
−
β
t
x
t
−
1
,
β
t
I
)
q(\bold{x_t|x_{t-1}})=\mathcal{N}(\bold{x_t};\sqrt{1-\beta_t}\bold{x_{t-1},\beta_t\bold{I}})
q(xt∣xt−1)=N(xt;1−βt
xt−1,βtI)
q ( x 1 : T ∣ x o ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q(\bold{x_{1:T}|x_o})=\prod^{T}_{t=1}q(\bold{x_t|x_{t-1}}) q(x1:T∣xo)=t=1∏Tq(xt∣xt−1)
这充分体现了参数重整化的技巧。
任意时刻的 q ( x t ) q(\bold{x_t}) q(xt)推导也可以完全基于 x 0 \bold{x}_0 x0和 β t \beta_t βt计算得到闭式解,而不需要做迭代。(令 α t = 1 − β t \alpha_t=1-\beta_t αt=1−βt)
两个正态分布 X ∼ N ( μ 1 , σ 1 ) X\sim \mathcal{N}(\mu_1,\sigma_1) X∼N(μ1,σ1)和 Y ∼ N ( μ 2 , σ 2 ) Y\sim \mathcal{N}(\mu_2,\sigma_2) Y∼N(μ2,σ2)叠加后的分布 a X + b Y aX+bY aX+bY服从分布 N ( a μ 1 + b μ 2 , a 2 σ 1 2 + b 2 σ 2 2 ) \mathcal{N}(a\mu_1+b\mu_2,a^2\sigma_1^2+b^2\sigma_2^2) N(aμ1+bμ2,a2σ12+b2σ22)。
对于第
t
t
t步的分布
x
t
x_t
xt等于上一步的分布
x
t
−
1
x_{t-1}
xt−1加上高斯噪声
z
t
−
1
z_{t-1}
zt−1,即
x
t
=
α
t
x
t
−
1
+
1
−
α
t
z
t
−
1
;
where
z
t
−
1
,
z
t
−
2
,
.
.
.
∼
N
(
0
,
I
)
=
α
t
α
t
−
1
x
t
−
2
+
α
t
−
α
t
α
t
−
1
z
t
−
2
+
1
−
α
t
z
t
−
1
这里借助参数重整化的技巧,将红色部分的两个高斯分布合并为新的高斯分布,整理如下所示
x
t
=
α
t
α
t
−
1
x
t
−
2
+
1
−
α
t
α
t
−
1
z
ˉ
t
−
2
其中,
z
ˉ
t
−
2
∼
N
(
0
,
I
)
\bar{\bold{z}}_{t-2}\sim \mathcal{N}(\bold{0},\bold{I})
zˉt−2∼N(0,I)
重复上面的步骤,最终可以得到
z
t
\bold{z}_t
zt的闭式解
x
t
=
α
ˉ
t
x
0
+
1
−
α
ˉ
t
z
;
where
α
ˉ
t
=
∏
i
=
1
T
α
i
\bold{x}_t=\sqrt{\bar{\alpha}_t}\bold{x}_0+\sqrt{1-\bar{\alpha}_{t}}\bold{z}\qquad ;\text{where}\ \bar{\alpha}_t=\prod_{i=1}^T\alpha_i
xt=αˉt
x0+1−αˉt
z;where αˉt=i=1∏Tαi
此时,作者认为
x
t
∼
N
(
x
t
;
α
ˉ
t
x
0
,
1
−
α
ˉ
t
I
)
\bold{x}_t\sim \mathcal{N}(\bold{x}_t;\sqrt{\bar{\alpha}_t}\bold{x}_0,\sqrt{1-\bar{\alpha}_t}\bold{I})
xt∼N(xt;αˉt
x0,1−αˉt
I),(此处应该是认为
x
0
\bold{x}_0
x0是完全已知的,方差为零),最终当上述分布趋近于
N
(
0
,
I
)
\mathcal{N}(\bold{0},\bold{I})
N(0,I)的时候,可认为模型已经基本完成扩散过程。因此,作者给出了一种
β
t
\beta_t
βt的设置经验,
β
1
<
β
2
<
⋅
⋅
⋅
<
β
T
\beta_1<\beta_2<\cdot\cdot\cdot<\beta_T
β1<β2<⋅⋅⋅<βT,即随着扩散深度的加深,逐步扩大
β
\beta
β。
逆过程是从高斯分布中恢复原始数据,当
β
t
\beta_t
βt足够小时,逆过程的每一小步
p
θ
(
x
t
−
1
∣
x
t
)
p_\theta(\bold{x}_{t-1}|\bold{x}_t)
pθ(xt−1∣xt)也可视作高斯分布,即
p
θ
(
x
t
−
1
∣
x
t
)
=
N
(
x
t
−
1
;
μ
θ
(
x
t
,
t
)
,
∑
θ
(
x
t
,
t
)
)
p_\theta(\bold{x}_{t-1}|\bold{x}_t)=\mathcal{N}(\bold{x}_{t-1};\bold{\mu_\theta}(\bold{x}_t,t),\sum_\theta(\bold{x}_t,t))
pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),θ∑(xt,t))
逆扩散过程可以被总结为如下形式
p
θ
(
x
0
:
T
)
=
p
(
x
T
)
∏
t
=
1
T
p
θ
(
x
t
−
1
∣
x
t
)
p_\theta(\bold{x}_{0:T})=p(\bold{x}_T)\prod_{t=1}^Tp_\theta (\bold{x}_{t-1}|\bold{x}_t)
pθ(x0:T)=p(xT)t=1∏Tpθ(xt−1∣xt)
此处通过使用网络估计参数
θ
\theta
θ以实现逆扩散过程。
根据条件概率的贝叶斯公式
q
(
x
t
−
1
∣
x
t
,
x
0
)
q
(
x
t
∣
x
0
)
=
q
(
x
t
∣
x
t
−
1
,
x
0
)
q
(
x
t
−
1
∣
x
0
)
q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0)q(\bold{x}_t|\bold{x}_0)=q(\bold{x}_{t}|\bold{x}_{t-1},\bold{x}_0)q(\bold{x}_{t-1}|\bold{x}_0)
q(xt−1∣xt,x0)q(xt∣x0)=q(xt∣xt−1,x0)q(xt−1∣x0)
得到
q
(
x
t
−
1
∣
x
t
,
x
0
)
=
q
(
x
t
∣
x
t
−
1
,
x
0
)
q
(
x
t
−
1
∣
x
0
)
q
(
x
t
∣
x
0
)
∝
exp
(
−
1
2
(
(
x
t
−
α
t
x
t
−
1
)
2
1
−
α
t
+
(
x
t
−
1
−
α
ˉ
t
−
1
x
0
)
2
1
−
α
ˉ
t
−
1
−
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
)
)
=
exp
(
−
1
2
(
a
x
t
−
1
2
+
b
x
t
−
1
+
c
(
x
t
,
x
0
)
)
)
可见,上述分布的核心可以用一个二次函数来描述,那对应的中轴线应该是
μ
~
t
(
x
t
,
x
0
)
=
−
b
2
a
=
α
t
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
x
t
+
α
ˉ
t
−
1
β
t
1
−
α
ˉ
x
0
容易从扩散过程的表达式(式11)得到
x
0
\bold{x}_0
x0的表达式
x
0
=
1
α
ˉ
t
(
x
t
−
1
−
α
ˉ
t
z
)
\bold{x}_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\bold{x}_t-\sqrt{1-\bar{\alpha}_{t}}\bold{z}\right)
x0=αˉt
1(xt−1−αˉt
z)
带入得到
μ
~
t
(
x
t
,
x
0
)
=
α
t
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
x
t
+
α
ˉ
t
−
1
β
t
1
−
α
ˉ
1
α
ˉ
t
(
x
t
−
1
−
α
ˉ
t
z
)
=
1
α
t
(
x
t
−
β
t
1
−
α
ˉ
t
z
t
)
这就是
x
t
−
1
\bold{x}_{t-1}
xt−1分布的均值表达式,即给定
x
0
\bold{x}_0
x0的条件下,后验条件高斯分布的均值计算只与
x
t
\bold{x}_{t}
xt和
z
t
\bold{z}_t
zt有关。
我们在待优化的目标数据分布的似然函数(负)上加一个非负的KL散度,构成负对数似然的上界,通过最小化上界,负对数似然自然取得最小。
−
log
p
θ
(
x
0
)
=
−
log
p
θ
(
x
0
)
+
D
K
L
(
q
(
x
1
:
T
∣
x
0
)
∣
∣
p
θ
(
x
1
:
T
∣
x
0
)
)
=
−
log
p
θ
(
x
0
)
+
E
x
1
:
T
∼
q
(
x
1
:
T
∣
x
0
)
[
log
q
(
x
1
:
T
∣
x
0
)
p
θ
(
x
0
:
T
)
/
p
θ
(
x
0
)
]
=
−
log
p
θ
(
x
0
)
+
E
x
1
:
T
∼
q
(
x
1
:
T
∣
x
0
)
[
log
q
(
x
1
:
T
∣
x
0
)
p
θ
(
x
0
:
T
)
+
log
p
θ
(
x
0
)
]
=
E
x
1
:
T
∼
q
(
x
1
:
T
∣
x
0
)
[
log
q
(
x
1
:
T
∣
x
0
)
p
θ
(
x
0
:
T
)
]
=
L
V
L
B
我们也可以继续对
L
V
L
B
L_{VLB}
LVLB进行展开,过程比较繁琐,建议查看论文,最终的形式如下
L
V
L
B
=
E
q
[
D
K
L
(
q
(
x
T
∣
x
0
)
∣
∣
p
θ
(
x
T
)
)
+
∑
t
=
1
T
D
K
L
(
q
(
x
t
−
1
∣
x
t
,
x
0
)
∣
∣
p
θ
(
x
t
−
1
∣
x
t
)
)
]
{\color{blue}{L_{VLB}}}=\mathbb{E}_q\left[D_{KL}\left(q(\bold{x}_T|\bold{x}_0)||p_\theta (\bold{x}_T)\right)+{\color{red} \sum_{t=1}^TD_{KL}(q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0)||p_\theta(\bold{x}_{t-1}|\bold{x}_t))}\right]
LVLB=Eq[DKL(q(xT∣x0)∣∣pθ(xT))+t=1∑TDKL(q(xt−1∣xt,x0)∣∣pθ(xt−1∣xt))]
其中第一项是不含待优化参数的,仅仅需要优化第二项即可。而且作者将
p
θ
(
x
t
−
1
∣
x
t
)
p_\theta(\bold{x}_{t-1}|\bold{x}_t)
pθ(xt−1∣xt)的方差设置为与
β
\beta
β有关的常数,可训练参数仅存在其均值中。
我们已经知道
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0)
q(xt−1∣xt,x0)服从高斯分布,并给出了其均值的表达式,而且知道
p
θ
(
x
t
−
1
∣
x
t
)
p_\theta(\bold{x}_{t-1}|\bold{x}_t)
pθ(xt−1∣xt)也服从高斯分布,其方差设置为常数,仅需优化均值即可。使用文章开头给出的两个单一变量的高斯分布的KL散度表达式,两个分布的方差均为常数,最终的损失函数可以写作两个分布的均值的关系:
L
t
−
1
=
E
q
[
1
2
σ
t
2
∣
∣
μ
~
t
(
x
t
,
x
0
)
−
μ
θ
(
x
t
,
t
)
∣
∣
2
]
+
C
{\color{red} L_{t-1}}=\mathbb{E}_q\left[\frac{1}{2\sigma_t^2}||\tilde{\bold\mu}_t(\bold{x}_t,\bold{x}_0)-\mu_\theta(\bold{x}_t,t)||^2\right]+C
Lt−1=Eq[2σt21∣∣μ~t(xt,x0)−μθ(xt,t)∣∣2]+C
我们可以将已经得到的
μ
t
\mu_t
μt的表达式,进行简化得到最终的损失函数:
L
simple
(
θ
)
:
=
E
t
,
x
0
,
ϵ
[
∣
∣
ϵ
−
ϵ
θ
(
α
ˉ
t
x
0
+
1
−
α
ˉ
t
ϵ
,
t
)
∣
∣
2
]
L_{\text{simple}}(\theta):=\mathbb{E}_{t,\bold{x}_0,\bold\epsilon}\left[||\bold\epsilon-\bold\epsilon_\theta(\sqrt{\bar{\alpha}_t}\bold{x}_0+\sqrt{1-\bar{\alpha}_t}\bold\epsilon,t)||^2\right]
Lsimple(θ):=Et,x0,ϵ[∣∣ϵ−ϵθ(αˉt
x0+1−αˉt
ϵ,t)∣∣2]
这里,
ϵ
θ
\epsilon_\theta
ϵθ就是可学习的网络,输入
x
0
\bold{x}_0
x0和高斯噪声
ϵ
\epsilon
ϵ以及时刻
t
t
t。
优化好网络 ϵ θ \epsilon_\theta ϵθ之后,可以从 x T x_T xT逐步获得 x 0 x_0 x0
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。