当前位置:   article > 正文

DDPM代码详解(一)

ddpm代码

        在Diffusion原理(一),Diffusion原理(二)中分别介绍了Diffusion(DDPM)的前向过程和逆向过程原理。这里将进一步介绍一下DDPM的代码实现.

首先我们从整体来思考一下,DDPM的代码实现,会包含哪些部分。可以利用思维导图大致梳理一下:

接下来根据思维导图来逐个模块实现。

        1. 计算\beta, 我们定义一个函数,来获取每一步的\beta. 在DDPM中,\beta采用的是线性生成方式。每一步的\beta成线性增加的。输入参数timesteps是总的步数。由于默认总步数是1000步,但是为了适配不同的定义,可以自己指定总步数,因此这里就会对每一步的步长进行缩放。具体代码如下:

  1. def linear_beta_schedule(timesteps):
  2. totalstep = 1000
  3. value_start = 0.0001
  4. value_end = 0.02
  5. scale = totalstep / timesteps
  6. beta_start = scale * value_start
  7. beta_end = scale * value_end
  8. return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

        2. 根据beta计算其他前向过程的参数。这里我们定义一个dGaussianDiffusion的类。先定义属性,利用这些属性(存放在register_buffer中)来设置这些固定的参数,不会参与参数更新。

  1. class GaussianDiffusion(nn.Module):
  2. def __init__(self,
  3. opts,
  4. device,
  5. network,
  6. min_snr_loss_weight=True):
  7. super().__init__()
  8. self.opts = opts
  9. self.device = device
  10. self.network = network.to(device)
  11. # define betas: beta1, beta2, ... beta_n
  12. beta_schedule_fn = linear_beta_schedule
  13. betas = beta_schedule_fn(self.opts['timesteps'])
  14. self.num_timesteps = int(betas.shape[0])
  15. # define alphas
  16. # get a1, a2, ..., an
  17. alphas = 1.0 - betas
  18. alphas_cumprod = torch.cumprod(alphas, dim=0)
  19. alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
  20. sqrt_recip_alphas = 1.0 / torch.sqrt(alphas)
  21. register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
  22. register_buffer('betas', betas)
  23. register_buffer('alphas_cumprod', alphas_cumprod)
  24. register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
  25. register_buffer('sqrt_recip_alphas', sqrt_recip_alphas)
  26. # calculations for diffusion q(x_t | x_{t-1}) and others
  27. # x_t = sqrt(alphas_cumprod)* x_0 + sqrt(1 - alphas_cumprod) * noise
  28. register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
  29. register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))
  30. # calculations for posterior q(x_{t - 1} | x_t, x_0)
  31. posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
  32. register_buffer('posterior_variance', posterior_variance)
  33. register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod)) # A
  34. register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod - 1.0)) # B
  35. # mu_{t - 1} = mean_coef1 * clip(x_{0}) + mean_coef2 * x_{t}
  36. register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
  37. register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
  38. register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) /
  39. (1.0 - alphas_cumprod))
  40. snr = alphas_cumprod / (1.0 - alphas_cumprod)
  41. maybe_clipped_snr =snr.clone()
  42. if min_snr_loss_weight:
  43. maybe_clipped_snr.clamp_(max=self.opts['min_snr_gamma'])
  44. register_buffer('loss_weight', maybe_clipped_snr / snr)
  45. self.ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])

        3. 接下来是正向过程:这里完全按照原理部分的公式来获取对应前向过程t时刻的样本。

  1. def q_sample(self, x_start, t, noise=None):
  2. sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
  3. sqrt_one_minus_alphas_cumprod_t = extract(
  4. self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
  5. )
  6. return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

        其中extract的实现如下,这个代表根据t来抽取对应位置的值。

  1. def extract(tensor, t, x_shape):
  2. batch_size =t.shape[0]
  3. out = tensor.gather(-1, t.cpu())
  4. return out.reshape(batch_size,
  5. *((1, ) * (len(x_shape) - 1))).to(t.device)

        比如 sqrt_alphas_cumprod_t 就是X_t = \sqrt{\alpha_t\alpha_{t-1}...\alpha_{1}}X_{0} + \sqrt{1 - \alpha_t\alpha_{t-1}...\alpha_{1}}\epsilon'_{0} 中的\sqrt{\alpha_t\alpha_{t-1}...\alpha_{1}}. sqrt_alphas_cumprod其实是所有从0-T的这些alpha在每个时刻都保存的一个tensor。

4. 训练时求loss。前面已经知道了前向过程,loss计算只拟合噪声,这个时候我们就可以得到直接计算loss了。在DDPM原理介绍中,我们最后推到出了整个loss其实可以表达为公式:

这里就会要求,求得每一步的噪声误差的总和。但是呢,为了如果每一步都去求,那么未免太耗时了。比如说总步长是1000, 那么每一个样本求1000次的loss总和未免太耗费训练时间了。这里为了简便,会对每个训练batch中的每张图,对应在[1, 1000]中随机生成一个步长timestep。比如一个batch有四个样本,每个样本的步长在[1, 1000]中随机生成,比如可以是[50, 100, 94, 786]。具体代码如下所示。

  1. def p_losses(self, x_start, t, noise=None):
  2. noise = default(noise, lambda: torch.randn_like(x_start))
  3. x_t = self.q_sample(x_start=x_start, t=t, noise=noise)
  4. network_out = self.network(x_t, t)
  5. target = noise
  6. if self.opts['loss_type'] == 'huber':
  7. loss = F.smooth_l1_loss(network_out, target, reduction='none')
  8. elif self.opts['loss_type'] == 'l1':
  9. loss = F.l1_loss(network_out, target, reduction='none')
  10. elif self.opts['loss_type'] == 'l2':
  11. loss = F.mse_loss(network_out, target, reduction='none')
  12. else:
  13. raise NotImplementedError()
  14. loss = reduce(loss, 'b ... -> b (...)', 'mean')
  15. loss = loss * extract(self.loss_weight, t, loss.shape)
  16. return loss.mean()
  17. def forward(self, img):
  18. b, _, _, _ = img.shape
  19. t = torch.randint(0, self.num_timesteps, (b,), device=self.device).long()
  20. return self.p_losses(img, t)

        5. 接下来要实现逆向生成过程:逆向生成过程也是按照前面原理部分的公式得到的,只是把数学公式用代码表达而已。这里的stable_sampling只是为了更好实现逆向生成,做了一些数学公式上的变换。

  1. @torch.inference_mode()
  2. def p_sample(self, x, t, t_index):
  3. betas_t = extract(self.betas, t, x.shape)
  4. sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
  5. sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
  6. x_mean = sqrt_recip_alphas_t * (
  7. x - betas_t * self.network(x, t) / sqrt_one_minus_alphas_cumprod_t
  8. )
  9. if t_index == 0:
  10. return x_mean
  11. else:
  12. posterior_variance_t = extract(self.posteroir_variance, t, x.shape)
  13. noise = torch.randn_like(x)
  14. return x_mean + torch.sqrt(posterior_variance_t) * noise

        上面的p_sample只是一步的逆向生成,要实现从Xt到X0的生成,需要一个循环,如下所示:

  1. @torch.inference_mode()
  2. def p_sample_loop(self, shape, return_all_timesteps=False):
  3. batch_size = self.opts['sample_batch_size']
  4. image = torch.randn(shape, device=self.device)
  5. return_images = [image.cpu().numpy()]
  6. for i in tqdm(reversed(range(0, self.opts['timesteps'])),
  7. desc='sampling loop time step',
  8. total=self.opts['timesteps']):
  9. image = self.p_sample(image, torch.full((batch_size, ), i,
  10. device=self.device, dtype=torch.long), i)
  11. if return_all_timesteps:
  12. return_images.append(image.cpu().numpy())
  13. else:
  14. if i == 0:
  15. return_images.append(image.cpu().numpy())
  16. return return_images

        至此,原始DDPM的代码已经实现了,还差一个生成噪声的network的定义,之后将在下一次代码详解中介绍。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/624600
推荐阅读
相关标签
  

闽ICP备14008679号