In Diffusion Principles (Part 1) and Diffusion Principles (Part 2) we covered the forward and reverse processes of Diffusion (DDPM). Here we walk through a DDPM code implementation.
First, let's think at a high level about which parts a DDPM implementation contains. A mind map sketches the modules roughly:
(mind-map figure)
We now implement each module in turn, following that outline.
1. Computing $\beta_t$. We define a function that returns $\beta_t$ for each step. DDPM uses a linear schedule, so $\beta_t$ increases linearly with $t$. The input parameter timesteps is the total number of steps. The default values assume 1000 steps, but to support a user-specified total the start and end values are rescaled accordingly. The code is as follows:
```python
import torch


def linear_beta_schedule(timesteps):
    # DDPM's defaults (0.0001 -> 0.02) are tuned for 1000 steps; rescale
    # the endpoints so a different step count covers a comparable range
    totalstep = 1000
    value_start = 0.0001
    value_end = 0.02
    scale = totalstep / timesteps
    beta_start = scale * value_start
    beta_end = scale * value_end
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
```
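As a quick sanity check (not from the original post), the schedule reproduces the DDPM defaults at 1000 steps and rescales for other step counts:

```python
betas = linear_beta_schedule(1000)
print(betas[0].item(), betas[-1].item())  # 0.0001, 0.02 (DDPM defaults)

# with fewer steps the endpoints are scaled up proportionally
betas_short = linear_beta_schedule(500)
print(betas_short[0].item(), betas_short[-1].item())  # 0.0002, 0.04
```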
2. Computing the other forward-process parameters from beta. We define a GaussianDiffusion class whose constructor precomputes these fixed quantities as attributes stored via register_buffer, so they are saved and moved with the module but never participate in gradient updates.
```python
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussianDiffusion(nn.Module):
    def __init__(self,
                 opts,
                 device,
                 network,
                 min_snr_loss_weight=True):
        super().__init__()

        self.opts = opts
        self.device = device

        self.network = network.to(device)

        # define betas: beta_1, beta_2, ..., beta_T
        beta_schedule_fn = linear_beta_schedule
        betas = beta_schedule_fn(self.opts['timesteps'])
        self.num_timesteps = int(betas.shape[0])

        # define alphas: alpha_t = 1 - beta_t
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
        sqrt_recip_alphas = 1.0 / torch.sqrt(alphas)

        # buffers: fixed tensors saved with the module, excluded from training
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        register_buffer('sqrt_recip_alphas', sqrt_recip_alphas)

        # calculations for diffusion q(x_t | x_0):
        # x_t = sqrt(alphas_cumprod) * x_0 + sqrt(1 - alphas_cumprod) * noise
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

        # calculations for the posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        register_buffer('posterior_variance', posterior_variance)

        # coefficients for recovering x_0 from x_t and predicted noise:
        # x_0 = sqrt_recip_alphas_cumprod * x_t - sqrt_recipm1_alphas_cumprod * noise
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod - 1.0))

        # posterior mean: mu_{t-1} = mean_coef1 * clip(x_0) + mean_coef2 * x_t
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) /
                        (1.0 - alphas_cumprod))

        # min-SNR loss weighting: clamp the signal-to-noise ratio so that
        # low-noise (high-SNR) timesteps do not dominate the loss
        snr = alphas_cumprod / (1.0 - alphas_cumprod)
        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max=self.opts['min_snr_gamma'])

        # for noise prediction the per-timestep weight is min(snr, gamma) / snr
        register_buffer('loss_weight', maybe_clipped_snr / snr)

        # named container for (pred_noise, pred_x_start) pairs
        self.ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
```
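To make the configuration concrete, here is a minimal instantiation sketch. The opts keys and their values are assumptions inferred from how self.opts is indexed above, and DummyNet is a stand-in for the real noise-prediction network (a U-Net in DDPM), not part of the original post:

```python
import torch
import torch.nn as nn

# hypothetical config; key names follow how self.opts is used above
opts = {
    'timesteps': 1000,
    'min_snr_gamma': 5,
    'loss_type': 'l2',
    'sample_batch_size': 4,
}

class DummyNet(nn.Module):
    # placeholder network: must accept (x_t, t) and return a tensor shaped like x_t
    def forward(self, x, t):
        return torch.zeros_like(x)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
diffusion = GaussianDiffusion(opts, device, DummyNet())
```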
3. Next, the forward process. This follows the closed-form formula from the theory part exactly: a sample at step $t$ is drawn directly from $x_0$ via $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$.
```python
def q_sample(self, x_start, t, noise=None):
    # sample x_t ~ q(x_t | x_0) in closed form; draw fresh noise if none given
    noise = default(noise, lambda: torch.randn_like(x_start))
    sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
```
The extract function used above is implemented as follows; it gathers the value at position t from a precomputed tensor, one entry per batch element.
```python
def extract(tensor, t, x_shape):
    # gather tensor[t] for each batch element (the buffers live on the CPU here),
    # then reshape to (b, 1, 1, ...) so it broadcasts over an input of shape x_shape
    batch_size = t.shape[0]
    out = tensor.gather(-1, t.cpu())
    return out.reshape(batch_size,
                       *((1, ) * (len(x_shape) - 1))).to(t.device)
```
For example, sqrt_alphas_cumprod_t is the $\sqrt{\bar{\alpha}_t}$ in the formula above. sqrt_alphas_cumprod is a tensor holding that value for every timestep from 0 to T, and extract picks out the entry matching each sample's t.
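One helper the post never shows is default, which q_sample above (and p_losses below) calls. A minimal version is given here as an assumption; it matches the same-named helper in lucidrains' denoising-diffusion-pytorch, which this code closely follows:

```python
def default(val, d):
    # return val if it was provided; otherwise fall back to d,
    # calling it first if it is callable (e.g. a lambda)
    if val is not None:
        return val
    return d() if callable(d) else d
```

With these helpers in place, noising a batch to chosen timesteps is a one-liner: diffusion.q_sample(x0, t), where t is a LongTensor of per-sample timesteps.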
4. Computing the training loss. With the forward process in hand, and the network fitting only the noise, we can compute the loss directly. In the DDPM theory post we derived that the whole loss reduces to the simplified form: $L_{\text{simple}} = \mathbb{E}_{t,\,x_0,\,\epsilon}\big[\,\lVert \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\; t) \rVert^2 \big]$
In principle this sums the noise-prediction error over every step, but evaluating every step would be far too slow: with a total of 1000 steps, each sample would need 1000 loss evaluations per update. To keep training cheap, each image in a batch is instead assigned a single timestep drawn uniformly at random from [1, 1000]. For a batch of four samples, the drawn timesteps might be [50, 100, 94, 786]. The code is as follows.
```python
from einops import reduce


def p_losses(self, x_start, t, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    # noise x_0 forward to x_t, then let the network predict that noise
    x_t = self.q_sample(x_start=x_start, t=t, noise=noise)
    network_out = self.network(x_t, t)
    target = noise

    if self.opts['loss_type'] == 'huber':
        loss = F.smooth_l1_loss(network_out, target, reduction='none')
    elif self.opts['loss_type'] == 'l1':
        loss = F.l1_loss(network_out, target, reduction='none')
    elif self.opts['loss_type'] == 'l2':
        loss = F.mse_loss(network_out, target, reduction='none')
    else:
        raise NotImplementedError()

    # average the per-element loss over all non-batch dims, then weight
    # each sample by the min-SNR factor for its timestep
    loss = reduce(loss, 'b ... -> b', 'mean')
    loss = loss * extract(self.loss_weight, t, loss.shape)
    return loss.mean()

def forward(self, img):
    # draw one random timestep per image in the batch
    b, _, _, _ = img.shape
    t = torch.randint(0, self.num_timesteps, (b,), device=self.device).long()
    return self.p_losses(img, t)
```
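Wiring this into training is then straightforward. A minimal sketch follows; the dataloader (assumed to yield image batches normalized to [-1, 1]), epoch count, and learning rate are assumptions, not part of the original post:

```python
# hypothetical training loop; `dataloader` and the hyperparameters are assumed
optimizer = torch.optim.Adam(diffusion.network.parameters(), lr=2e-4)

for epoch in range(10):
    for img in dataloader:
        img = img.to(device)
        loss = diffusion(img)  # forward() draws a random t per image
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```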
5. Next, the reverse (generation) process. This too is just the theory-part formulas expressed as code; the only twist is a small algebraic rearrangement for more stable sampling, writing the posterior mean directly in terms of the predicted noise: $\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\Big)$
```python
@torch.inference_mode()
def p_sample(self, x, t, t_index):
    betas_t = extract(self.betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
    # posterior mean written in terms of the predicted noise:
    # mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphas_cumprod_t) * eps_theta)
    x_mean = sqrt_recip_alphas_t * (
        x - betas_t * self.network(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        # final step: return the mean without adding noise
        return x_mean
    else:
        posterior_variance_t = extract(self.posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return x_mean + torch.sqrt(posterior_variance_t) * noise
```
p_sample above performs only a single reverse step; to generate all the way from $x_T$ back to $x_0$, we need a loop over every timestep, as shown below:
```python
from tqdm import tqdm


@torch.inference_mode()
def p_sample_loop(self, shape, return_all_timesteps=False):
    batch_size = self.opts['sample_batch_size']
    # start from pure Gaussian noise x_T
    image = torch.randn(shape, device=self.device)
    return_images = [image.cpu().numpy()]

    # walk t = T-1, ..., 0, denoising one step at a time
    for i in tqdm(reversed(range(0, self.opts['timesteps'])),
                  desc='sampling loop time step',
                  total=self.opts['timesteps']):
        image = self.p_sample(image, torch.full((batch_size, ), i,
                              device=self.device, dtype=torch.long), i)
        if return_all_timesteps:
            return_images.append(image.cpu().numpy())
        else:
            if i == 0:
                return_images.append(image.cpu().numpy())

    return return_images
```
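Calling it for a batch of, say, 32x32 RGB images might look like the sketch below; the image size and the rescaling back to [0, 1] for display are assumptions:

```python
# hypothetical usage: sample a batch of 3x32x32 images
shape = (opts['sample_batch_size'], 3, 32, 32)
images = diffusion.p_sample_loop(shape)
final = images[-1]                      # numpy array, roughly in [-1, 1]
final = ((final + 1) / 2).clip(0, 1)    # rescale to [0, 1] for display
```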
At this point the original DDPM code is complete, except for one missing piece: the definition of the noise-prediction network, which will be covered in the next code walkthrough.