We still do not know what form $p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)$ takes. The first diffusion-model paper showed that it is likewise Gaussian, a result that apparently traces back to non-equilibrium thermodynamics and that we will not derive here. What we need to solve for is the mean and variance of that Gaussian such that the loss function is minimized. The original paper gives the form:

$$p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right), \boldsymbol{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right)$$
Now look at the second term of the loss function,

$$\sum_{t=2}^T D_{KL}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right),$$

which we denote $L_t$ for convenience. The KL divergence between two Gaussians with the same fixed variance reduces to an L2 loss between their means, up to a constant coefficient; this is a standard and easily derived result, so we skip the derivation here. Expanding the KL term gives:

$$L_t=\mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\frac{1}{2\left\|\boldsymbol{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right\|_2^2}\left\|\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)-\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right]$$
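For reference, the Gaussian KL identity being used here (for two Gaussians sharing a fixed covariance $\sigma^2 \mathbf{I}$) is:

$$D_{KL}\left(\mathcal{N}\left(\boldsymbol\mu_1, \sigma^2 \mathbf{I}\right) \,\|\, \mathcal{N}\left(\boldsymbol\mu_2, \sigma^2 \mathbf{I}\right)\right)=\frac{1}{2 \sigma^2}\left\|\boldsymbol\mu_1-\boldsymbol\mu_2\right\|^2,$$

which is exactly the weighted mean-squared error above.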
For

$$\tilde{\boldsymbol\mu}\left(\mathbf{x}_t, \mathbf{x}_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0,$$

we want to express it in terms of $\mathbf{x}_t$ alone. The forward process gives

$$\mathbf{x}_t=\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol\epsilon,$$

so $\mathbf{x}_0=\left(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\, \boldsymbol\epsilon\right) / \sqrt{\bar{\alpha}_t}$; substituting this in yields

$$\tilde{\boldsymbol\mu}\left(\mathbf{x}_t, \mathbf{x}_0\right)=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol\epsilon_t\right).$$
Correspondingly, $\boldsymbol\mu_\theta\left(\mathbf{x}_t, t\right)$ can be parameterized as

$$\boldsymbol\mu_\theta\left(\mathbf{x}_t, t\right)=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right)\right),$$

where $\boldsymbol\epsilon_t$ denotes the concrete noise $\boldsymbol\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ added at step $t$ of the forward process, i.e. the actual sampled value. The loss above then becomes:
$$\begin{aligned} L_t &=\mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\frac{1}{2\left\|\boldsymbol{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right\|_2^2}\left\|\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)-\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] \\ &=\mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\frac{1}{2\left\|\boldsymbol{\Sigma}_\theta\right\|_2^2}\left\|\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol\epsilon_t\right)-\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right)\right)\right\|^2\right] \\ &=\mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\frac{\left(1-\alpha_t\right)^2}{2 \alpha_t\left(1-\bar{\alpha}_t\right)\left\|\boldsymbol{\Sigma}_\theta\right\|_2^2}\left\|\boldsymbol\epsilon_t-\boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] \end{aligned}$$
Here $\boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right)$ is the quantity predicted by a neural network: $\theta$ denotes the network parameters, and $\boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right)$ is the network's output. So in practice, training only requires predicting the noise added at each time step $t$ and computing an L2 loss against the true noise $\boldsymbol\epsilon_t$; driving $L_t$ down step by step achieves the original goal of maximizing $\log p(\mathbf{x}_0)$.
One detail: $\boldsymbol{\Sigma}_\theta$ is set to a fixed value here, so it can be pulled out into the constant coefficient in front. OpenAI's "Improved Denoising Diffusion Probabilistic Models" revisits this choice and makes the variance depend on learned parameters, which changes the $L_t$ formula somewhat, but the underlying idea is unchanged; interested readers can experiment with it.
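For reference, that paper parameterizes the learned variance as a log-space interpolation between the two analytic extremes $\beta_t$ and $\tilde{\beta}_t$, with the network emitting an extra per-dimension output $v$:

$$\boldsymbol{\Sigma}_\theta\left(\mathbf{x}_t, t\right)=\exp\left(v \log \beta_t+(1-v) \log \tilde{\beta}_t\right), \quad \text{where } \tilde{\beta}_t=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t.$$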
In the 2020 paper "Denoising Diffusion Probabilistic Models", the authors found empirically that optimization of $L_t$ works better if the weighting coefficient in front is simply dropped:

$$L_t=\mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\left\|\boldsymbol\epsilon_t-\boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right]$$
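As a minimal sketch of what this simplified objective looks like in code (hypothetical names: `eps_model` is any noise-prediction network, and `alphas_cumprod` is a precomputed $\bar{\alpha}_t$ table like the one built in the implementation below):

import torch

def simple_ddpm_loss(eps_model, x_0, alphas_cumprod):
    # Sample a random timestep per image and the true noise eps_t
    b = x_0.shape[0]
    t = torch.randint(0, len(alphas_cumprod), (b,), device=x_0.device)
    a_bar = alphas_cumprod.to(x_0.device)[t].view(b, 1, 1, 1)
    noise = torch.randn_like(x_0)
    # Forward process in closed form: x_t = sqrt(a_bar) x_0 + sqrt(1 - a_bar) eps
    x_t = a_bar.sqrt() * x_0 + (1.0 - a_bar).sqrt() * noise
    # L2 between true and predicted noise, i.e. L_t without the weight
    return ((noise - eps_model(x_t, t)) ** 2).mean()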
Next, we can finally return to the original optimization objective (taking the expectation over $q$ turns the log-ratios into KL divergences):

$$\begin{aligned} \log \left(\frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right) &=\log \left(\frac{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}{p\left(\mathbf{x}_T\right)}\right)+\sum_{t=2}^T \log \left(\frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}\right)-\log \left(p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right) \\ &\equiv D_{KL}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)+\sum_{t=2}^T D_{KL}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)-\log \left(p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right) \\ &=\text{constant}+\mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\left\|\boldsymbol\epsilon_t-\boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right]-\log \left(p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right) \end{aligned}$$
This leaves the final term $-\log \left(p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right)$. For it, the paper uses a separate network head (the paper calls it a discrete decoder) to predict the image at $t=0$ rather than the noise (predicting $\mathbf{x}_0$ requires knowledge of $\mathbf{x}_1$):
$$p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)=\prod_{i=1}^D \int_{\delta_{-}\left(x_0^i\right)}^{\delta_{+}\left(x_0^i\right)} \mathcal{N}\left(x ; \mu_\theta^i\left(\mathbf{x}_1, 1\right), \sigma_1^2\right) d x$$

$$\delta_{+}(x)=\begin{cases}\infty & \text{if } x=1 \\ x+\frac{1}{255} & \text{if } x<1\end{cases} \qquad \delta_{-}(x)=\begin{cases}-\infty & \text{if } x=-1 \\ x-\frac{1}{255} & \text{if } x>-1\end{cases}$$
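Concretely, for an interior pixel value ($-1 < x_0^i < 1$) this integral evaluates to a difference of Gaussian CDFs,

$$\Phi\left(\frac{x_0^i+\frac{1}{255}-\mu_\theta^i}{\sigma_1}\right)-\Phi\left(\frac{x_0^i-\frac{1}{255}-\mu_\theta^i}{\sigma_1}\right),$$

i.e. the probability mass the model assigns to that discrete pixel bin.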
In the simplified version, however, the authors fold this term into the $L_t$ above as the $t=1$ case, so the final optimization objective is:

$$\mathcal{L}=\text{constant} + \mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\left\|\boldsymbol{\epsilon}_t-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right]$$

In other words, all we have to do is predict the noise added at each time step.
1. Forward diffusion process
# forward process
import torch
import torch.nn.functional as F

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it.
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
2. Precompute the $\alpha$, $\beta$, and related schedule terms
# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)
# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
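A quick sanity check of these precomputed tensors (a hypothetical snippet, assuming the definitions above have been run):

# Each schedule tensor has one entry per timestep
assert betas.shape == alphas_cumprod.shape == (T,)
# alphas_cumprod_prev is alphas_cumprod shifted right, with a_bar_0 := 1
assert alphas_cumprod_prev[0] == 1.0
assert torch.allclose(alphas_cumprod_prev[1:], alphas_cumprod[:-1])
# The posterior variance beta_tilde_t never exceeds beta_t
assert torch.all(posterior_variance <= betas)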
3. Load a dataset and test the forward process
# test on the car dataset
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

IMG_SIZE = 64
BATCH_SIZE = 16

def load_transformed_dataset():
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1)  # Scale between [-1, 1]
    ]
    data_transform = transforms.Compose(data_transforms)
    train = torchvision.datasets.StanfordCars(root=".", download=True,
                                              transform=data_transform)
    test = torchvision.datasets.StanfordCars(root=".", download=True,
                                             transform=data_transform, split='test')
    return torch.utils.data.ConcatDataset([train, test])

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])
    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))

data = load_transformed_dataset()
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Simulate forward diffusion (optional visualization)
image = next(iter(dataloader))[0]
plt.figure(figsize=(15, 15))
plt.axis('off')
num_images = 10
stepsize = int(T / num_images)
for idx in range(0, T, stepsize):
    t = torch.Tensor([idx]).type(torch.int64)
    plt.subplot(1, num_images + 1, int(idx / stepsize) + 1)
    image, noise = forward_diffusion_sample(image, t)
    show_tensor_image(image)
plt.show()
4. Define the loss function
# get loss
def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    # N.B. the derivation above yields an L2 objective; this tutorial
    # uses an L1 loss instead, which also works well in practice
    return F.l1_loss(noise, noise_pred)
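If you want the code to match the derivation exactly, swapping in an MSE loss is a one-line change (hypothetical variant):

def get_loss_l2(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.mse_loss(noise, noise_pred)  # L2, matching L_t above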
5. Sampling (inference): generation has to iterate the reverse process step by step, so it is computationally expensive. Note that `model_mean` below implements the $\boldsymbol\mu_\theta(\mathbf{x}_t, t)$ formula derived earlier.
@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    # Call model (current image - noise prediction); this is mu_theta(x_t, t)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_plot_image():
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15, 15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T / num_images)
    # Iterate from t = T - 1 all the way down to t = 0,
    # but only plot a few of the intermediate images
    for i in range(0, T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize == 0:
            plt.subplot(1, int(num_images), int(i / stepsize) + 1)
            show_tensor_image(img.detach().cpu())
    plt.show()

from torchvision.utils import save_image

@torch.no_grad()
def save_sampled_image(epoch):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    num_images = 10
    stepsize = int(T / num_images)
    # Iterate from t = T - 1 all the way down to t = 0,
    # but only save a few of the intermediate images
    for i in range(0, T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize == 0:
            trans = transforms.Lambda(lambda t: (t + 1) / 2)
            if len(img.shape) == 4:
                image = img[0, :, :, :]
            image = trans(image)
            save_image(image, './results/' + str(epoch) + '_' + str(i) + '.jpg')
6. Training loop. The sampling code appears above because predicting $\mathbf{x}_0$ itself requires iteratively sampling the reverse process.
from torch.optim import Adam
from implementation2.model import SimpleUnet  # the model defined in section 7

model = SimpleUnet()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)
epochs = 100  # Try more!

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
        loss = get_loss(model, batch[0], t)
        loss.backward()
        optimizer.step()
        # Test the current model every 5 epochs
        if epoch % 5 == 0 and step == 0:
            print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
            # sample_plot_image()
            save_sampled_image(epoch)
7. The model (transformer-style sinusoidal position embeddings encode the timestep; a U-Net encodes the image)
from torch import nn
import math
import torch

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1  # kernel size of the final 1x1 projection
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1],
                                          time_emb_dim)
                                    for i in range(len(down_channels) - 1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1],
                                        time_emb_dim, up=True)
                                  for i in range(len(up_channels) - 1)])
        self.output = nn.Conv2d(up_channels[-1], 3, out_dim)

    def forward(self, x, timestep):
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)
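A quick smoke test of the network (hypothetical snippet): the U-Net must map a noisy image plus a timestep to a noise prediction of the same shape as the image.

model = SimpleUnet()
x = torch.randn(1, 3, 64, 64)           # a dummy batch of one 64x64 RGB image
t = torch.randint(0, 300, (1,)).long()  # a random timestep
assert model(x, t).shape == x.shape     # noise prediction matches input shape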
These two posts cover only the most basic theory, but personally I feel it is important to understand how things came about and what the starting point of an idea is. Whether you later use the model itself or just its underlying ideas, that understanding helps you think about problems more deeply and flexibly.
When GANs first took off, a popular YouTube comment called them the coolest idea in deep learning of the past 20 years, but later research gradually exposed many of GANs' problems, and the idea came to look less perfect than it did at its peak. It remains to be seen whether diffusion models will follow the same path.