赞
踩
来源:超详细的扩散模型(Diffusion Models)原理+代码 - 知乎 (zhihu.com)
代码:drizzlezyk/DDPM-MindSpore (github.com)
1.1 正弦位置编码
- class SinusoidalPosEmb(nn.Cell):
- def __init__(self, dim):
- super().__init__()
- half_dim = dim // 2 # 将给定的维度除以2得到半维度
- emb = math.log(10000) / (half_dim - 1) # 计算位置编码的参数
- emb = np.exp(np.arange(half_dim) * -emb) # 根据半维度创建正弦位置编码矩阵
- self.emb = Tensor(emb, mindspore.float32) # 将矩阵转换为Tensor,并存储在类属性中
- self.Concat = _get_cache_prim(ops.Concat)(-1) # 定义连接操作
-
- def construct(self, x):
- emb = x[:, None] * self.emb[None, :] # 对输入张量进行位置编码计算
- emb = self.Concat((ops.sin(emb), ops.cos(emb))) # 将正弦和余弦编码连接起来
- return emb
1.2 Attention
- class Attention(nn.Cell):
- def __init__(self, dim, heads=4, dim_head=32):
- super().__init__()
- self.scale = dim_head ** -0.5 # 缩放因子,用于调整注意力权重
- self.heads = heads # 头数,决定了注意力机制的复杂度
- hidden_dim = dim_head * heads # 隐藏维度,等于头数乘以每个头的维度
-
- # 将输入进行线性变换得到查询、键和值
- self.to_qkv = _get_cache_prim(Conv2d)(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
- # 将注意力加权后的结果再进行线性变换得到最终的输出
- self.to_out = _get_cache_prim(Conv2d)(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
-
- self.map = ops.Map()
- self.partial = ops.Partial()
- self.bmm = BMM() # 矩阵相乘操作
- self.split = ops.Split(axis=1, output_num=3) # 在指定维度上将张量分割成多个部分
- self.softmax = ops.Softmax(-1) # 对最后一维进行softmax操作
-
- def construct(self, x):
- b, c, h, w = x.shape # 获取输入张量的形状信息
- qkv = self.split(self.to_qkv(x)) # 将输入进行线性变换得到查询、键和值,并将其分割成三个部分
- q, k, v = self.map(self.partial(rearrange, self.heads), qkv) # 对查询、键和值进行重排操作,以便多头注意力机制的计算
- q = q * self.scale # 缩放查询向量
-
- sim = self.bmm(q.swapaxes(2, 3), k) # 计算查询和键的相似度,使用矩阵乘法实现并行计算
- attn = self.softmax(sim) # 使用softmax函数对相似度进行归一化,得到注意力权重
- out = self.bmm(attn, v.swapaxes(2, 3)) # 将注意力权重与值相乘,得到加权后的结果
- out = out.swapaxes(-1, -2).reshape((b, -1, h, w)) # 将结果进行维度转换和形状调整,得到最终的输出
- return self.to_out(out) # 将输出进行线性变换得到最终的注意力机制输出
1.3 Residual Block
- class Residual(nn.Cell):
- """残差块"""
- def __init__(self, fn):
- super().__init__()
- self.fn = fn # 残差块的函数或模型
-
- def construct(self, x, *args, **kwargs):
- return self.fn(x, *args, **kwargs) + x # 将输入与通过函数或模型处理后的结果相加作为输出
定义相关的概率值,与公式相对应:
- self.betas = betas # 初始化betas参数
- self.alphas_cumprod = alphas_cumprod # 初始化alphas_cumprod参数
- self.alphas_cumprod_prev = alphas_cumprod_prev # 初始化alphas_cumprod_prev参数
-
- # 计算扩散 q(x_t | x_{t-1}) 和其他参数
- self.sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod)) # 计算alphas_cumprod的平方根
- self.sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod)) # 计算1 - alphas_cumprod的平方根
- self.log_one_minus_alphas_cumprod = Tensor(np.log(1. - alphas_cumprod)) # 计算log(1 - alphas_cumprod)
- self.sqrt_recip_alphas_cumprod = Tensor(np.sqrt(1. / alphas_cumprod)) # 计算1 / alphas_cumprod的平方根
- self.sqrt_recipm1_alphas_cumprod = Tensor(np.sqrt(1. / alphas_cumprod - 1)) # 计算1 / alphas_cumprod - 1的平方根
-
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # 计算后验方差
-
- self.posterior_variance = Tensor(posterior_variance) # 存储后验方差
- self.posterior_log_variance_clipped = Tensor(
- np.log(np.clip(posterior_variance, 1e-20, None))) # 计算后验方差的对数,并进行截断
- self.posterior_mean_coef1 = Tensor(
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) # 计算后验均值的系数1
- self.posterior_mean_coef2 = Tensor(
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)) # 计算后验均值的系数2
-
- p2_loss_weight = (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** - p2_loss_weight_gamma # 计算p2_loss_weight
- self.p2_loss_weight = Tensor(p2_loss_weight) # 存储p2_loss_weight参数
计算损失:
基于Unet预测出noise,使用预测noise和真实noise计算损失:
- self.betas = betas # 初始化betas参数
- self.alphas_cumprod = alphas_cumprod # 初始化alphas_cumprod参数
- self.alphas_cumprod_prev = alphas_cumprod_prev # 初始化alphas_cumprod_prev参数
-
- # 计算扩散 q(x_t | x_{t-1}) 和其他参数
- self.sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod)) # 计算alphas_cumprod的平方根
- self.sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod)) # 计算1 - alphas_cumprod的平方根
- self.log_one_minus_alphas_cumprod = Tensor(np.log(1. - alphas_cumprod)) # 计算log(1 - alphas_cumprod)
- self.sqrt_recip_alphas_cumprod = Tensor(np.sqrt(1. / alphas_cumprod)) # 计算1 / alphas_cumprod的平方根
- self.sqrt_recipm1_alphas_cumprod = Tensor(np.sqrt(1. / alphas_cumprod - 1)) # 计算1 / alphas_cumprod - 1的平方根
-
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # 计算后验方差
-
- self.posterior_variance = Tensor(posterior_variance) # 存储后验方差
- self.posterior_log_variance_clipped = Tensor(
- np.log(np.clip(posterior_variance, 1e-20, None))) # 计算后验方差的对数,并进行截断
- self.posterior_mean_coef1 = Tensor(
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) # 计算后验均值的系数1
- self.posterior_mean_coef2 = Tensor(
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)) # 计算后验均值的系数2
-
- p2_loss_weight = (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** - p2_loss_weight_gamma # 计算p2_loss_weight
- self.p2_loss_weight = Tensor(p2_loss_weight) # 存储p2_loss_weight参数
采样:
输出x_start,也就是原始图像,当sampling_time_steps< time_steps,用下方函数:
- def ddim_sample(self, shape, clip_denoise=True):
- batch = shape[0]
- total_timesteps, sampling_timesteps, = self.num_timesteps, self.sampling_timesteps
- eta, objective = self.ddim_sampling_eta, self.objective
-
- # 创建采样时间步列表,[-1, 0, 1, 2, ..., T-1],当sampling_timesteps == total_timesteps时
- times = np.linspace(-1, total_timesteps - 1, sampling_timesteps + 1).astype(np.int32)
- # 创建时间对列表,[(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
- times = list(reversed(times.tolist()))
- time_pairs = list(zip(times[:-1], times[1:]))
-
- img = np.random.randn(*shape).astype(np.float32) # 随机初始化图像
- x_start = None
-
- for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
- time_cond = np.full((batch,), time).astype(np.int32) # 创建与批次大小相同的时间条件
- x_start = Tensor(x_start) if x_start is not None else x_start
- self_cond = x_start if self.self_condition else None
- predict_noise, x_start, *_ = self.model_predictions(Tensor(img, mindspore.float32),
- Tensor(time_cond),
- self_cond,
- clip_denoise)
- predict_noise, x_start = predict_noise.asnumpy(), x_start.asnumpy()
- if time_next < 0:
- img = x_start
- continue
-
- alpha = self.alphas_cumprod[time]
- alpha_next = self.alphas_cumprod[time_next]
-
- sigma = eta * np.sqrt(((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)))
- c = np.sqrt(1 - alpha_next - sigma ** 2)
-
- noise = np.random.randn(*img.shape)
-
- img = x_start * np.sqrt(alpha_next) + c * predict_noise + sigma * noise
-
- img = self.unnormalize(img) # 反归一化图像
-
- return img
data_iterator中每次取出的数据集就是一个batch_size大小,每训练一个batch,self.step就会加1。
DDPM的trainer采用ema(指数移动平均)优化,ema不参与训练,只参与推理,比对变量直接赋值而言,移动平均得到的值在图像上更加平缓光滑,抖动性更小。具体代码参考代码仓中ema.py
- print('training start')
- with tqdm(initial=self.step, total=self.train_num_steps, disable=False) as pbar:
- total_loss = 0.
- for (img,) in data_iterator:
- model.set_train()
- time_emb = Tensor(
- np.random.randint(0, num_timesteps, (img.shape[0],)).astype(np.int32)) # 随机生成时间向量
- noise = Tensor(np.random.randn(*img.shape), mindspore.float32) # 生成噪声向量
-
- self_cond = random.random() < 0.5 if self.self_condition else False # 根据self_condition参数决定是否进行自我条件训练
- loss = train_step(img, time_emb, noise, self_cond) # 调用train_step函数返回损失值
-
- total_loss += float(loss.asnumpy()) # 累加损失值
-
- self.step += 1
- if self.step % gradient_accumulate_every == 0:
- self.ema.update() # 更新EMA模型的参数
- pbar.set_description(f'loss: {total_loss:.4f}')
- pbar.update(1)
- total_loss = 0.
-
- accumulate_step = self.step // gradient_accumulate_every
- accumulate_remain_step = self.step % gradient_accumulate_every
- if self.step != 0 and accumulate_step % self.save_and_sample_every == 0 and accumulate_remain_step == 0:
- self.ema.set_train(False)
- self.ema.synchronize()
- batches = num_to_groups(self.num_samples, self.batch_size)
- all_images_list = list(map(lambda n: self.ema.online_model.sample(batch_size=n), batches))
- self.save_images(all_images_list, accumulate_step) # 保存生成的图像
- self.save(accumulate_step) # 保存模型的参数
- self.ema.desynchronize()
-
- if self.step >= gradient_accumulate_every * self.train_num_steps:
- break
-
- print('training complete')
来源:一文读懂Stable Diffusion 论文原理+代码超详细解读 - 知乎 (zhihu.com)
AutoEncoderKL 编码器已提前训练好,参数是固定的。训练阶段该模块负责将输入数据集映射到latent space,然后latent space的样本再继续进入扩散模型进行扩散。这一过程在Stable Diffusion代码中被称为 encode_first_stage:
- def get_input(self, x, c):
- # 检查输入 x 的维度是否为3。如果是,则通过使用切片操作 [..., None] 在最后添加一个额外的维度,将其转换为4维张量。
- if len(x.shape) == 3:
- x = x[..., None]
- # 维度转置操作,将 x 的维度顺序从原来的 (batch_size, height, width, channels) 转换为 (batch_size, channels, height, width)。
- x = self.transpose(x, (0, 3, 1, 2))
- # 对输入 x 进行编码操作,并乘以一个名为 scale_factor 的常量。然后,使用 stop_gradient 方法对结果进行梯度停止,即在计算梯度时不会考虑这个部分。
- z = ops.stop_gradient(self.scale_factor * self.first_stage_model.encode(x))
-
- return z, c
/encoders/modules.py
将控制条件编码为向量。其核心模块class TextEncoder(nn.Cell)构建函数如下:
- def construct(self, text):
- bsz, ctx_len = text.shape
- flatten_id = text.flatten() # 将输入文本 text 展平为一维张量。展平操作将多维数组转换为一维,保留原始元素顺序。
- gather_result = self.gather(self.embedding_table, flatten_id, 0) # 从 embedding_table 中根据 flatten_id 提取对应的嵌入向量。
- x = self.reshape(gather_result, (bsz, ctx_len, -1)) # 重塑操作,将 gather_result 重新调整为指定形状 (bsz, ctx_len, -1) 的张量。
- x = x + self.positional_embedding # 引入位置编码
- x = x.transpose(1, 0, 2) # 对 x 进行维度转置操作,将维度顺序从原来的 (bsz, ctx_len, -1) 转换为 (ctx_len, bsz, -1)。
- x = self.transformer_layer(x) # 再次对 x 进行维度转置操作,将维度顺序从 (ctx_len, bsz, -1) 转换回 (bsz, ctx_len, -1)。
- x = x.transpose(1, 0, 2)
- x = self.ln_final(x)
- return x
- # 这段代码根据条件选择性地创建并添加 AttentionBlock 或 SpatialTransformer 对象到 layers 列表中,并将 layers 列表作为一个整体追加到 self.input_blocks 列表
- layers.append(AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- ) if not use_spatial_transformer else SpatialTransformer(
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
- use_checkpoint=use_checkpoint, dtype=self.dtype, dropout=self.dropout, use_linear=use_linear_in_transformer
- )
- )
- self.input_blocks.append(layers)
可以看出UNet的每个中间层都会拼接一次SpatialTransformer模块,该模块对应,使用 Attention 机制来更好的学习文本与图像的匹配关系。
- # 对输入数据进行一系列的处理和变换操作,并生成模型的输出
- def construct(self, x, timesteps=None, context=None, y=None):
- """
- Apply the model to an input batch.
- :param x: an [N x C x ...] Tensor of inputs.
- :param timesteps: a 1-D batch of timesteps.
- :param context: conditioning plugged in via crossattn
- :param y: an [N] Tensor of labels, if class-conditional.
- :return: an [N x C x ...] Tensor of outputs.
- """
-
- assert (y is not None) == (
- self.num_classes is not None
- ), "must specify y if and only if the model is class-conditional"
-
- # 计算了时间步嵌入(timestep embedding)。首先,使用 timestep_embedding 方法生成时间步嵌入向量 t_emb,并根据 self.model_channels 进行相应的通道调整;然后,使用 self.time_embed 方法对 t_emb 进行进一步的时间嵌入处理,得到最终的嵌入向量 emb。
- hs = []
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
- emb = self.time_embed(t_emb)
-
- # 如果模型是类别条件的,则进一步处理嵌入向量 emb。首先,使用断言语句验证标签 y 的形状是否与输入 x 的批次维度匹配;然后,使用 self.label_emb 方法对标签 y 进行嵌入处理,并将其与嵌入向量 emb 相加。
- if self.num_classes is not None:
- assert y.shape == (x.shape[0],)
- emb = emb + self.label_emb(y)
-
- h = x
- for celllist in self.input_blocks:
- for cell in celllist:
- h = cell(h, emb, context)
- hs.append(h)
-
- for module in self.middle_block:
- h = module(h, emb, context)
-
- hs_index = -1
- for celllist in self.output_blocks:
- h = self.cat((h, hs[hs_index]))
- for cell in celllist:
- h = cell(h, emb, context)
- hs_index -= 1
-
- if self.predict_codebook_ids:
- return self.id_predictor(h)
- else:
- return self.out(h)
扩散模型,用于生成对应采样时间t的样本
- def p_losses(self, x_start, cond, t, noise=None):
- noise = ms.numpy.randn(x_start.shape) # 生成与 x_start 相同形状的随机噪声 noise
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # 基于给定的输入 x_start、时间步 t 和噪声 noise 生成加噪后的样本 x_noisy
- model_output = self.apply_model(x_noisy, t, cond) // UNet预测的噪声,cond表示FrozenCLIPEmbedder生成的条件
-
- # 根据参数化方式(parameterization),选择目标值 target。如果参数化方式是 "x0",则将目标值设置为输入 x_start;如果是 "eps",则将目标值设置为随机噪声 noise。如果参数化方式不在这两种情况中,则抛出异常。
- if self.parameterization == "x0":
- target = x_start
- elif self.parameterization == "eps":
- target = noise
- else:
- raise NotImplementedError()
-
- loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) //计算预测noise与真实noise的损失值。mean=False 表示不进行均值操作。然后,在维度 [1, 2, 3] 上进行均值操作,得到简单损失 loss_simple。
-
- logvar_t = self.logvar[t]
- loss = loss_simple / ops.exp(logvar_t) + logvar_t
- loss = self.l_simple_weight * loss.mean()
-
- loss_vlb = self.get_loss(model_output, target, mean=False).mean((1, 2, 3))
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
- loss += (self.original_elbo_weight * loss_vlb) # 最后,将乘积结果与原始ELBO(Evidence Lower Bound)权重 self.original_elbo_weight 相加,更新总体损失 loss。
-
- return loss
self.apply_model代码如下:
- # 该方法将噪声样本输入模型进行处理,并生成重构的输出结果
- # 输入的噪声样本 x_noisy、时间步 t 和条件 cond。首先使用 ops.cast() 将输入的 x_noisy 和 cond 强制转换为指定的数据类型 self.dtype。
- def apply_model(self, x_noisy, t, cond, return_ids=False):
- x_noisy = ops.cast(x_noisy, self.dtype)
- cond = ops.cast(cond, self.dtype)
-
- # 检查条件 cond 是否为字典类型。如果是,则表示为混合情况,不做任何操作。
- # 如果条件不是字典类型,则根据模型的 conditioning_key 属性,选择键值对的键名 key。
- # 如果 conditioning_key 是 'concat',则将键名设置为 'c_concat';否则,将键名设置为 'c_crossattn'。然后,将条件 cond 转换为包含单个键值对的字典。
- if isinstance(cond, dict):
- # hybrid case, cond is expected to be a dict
- pass
- else:
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
- cond = {key: cond}
-
- x_recon = self.model(x_noisy, t, **cond) // self.model表示UNet模型。根据传递给模型的参数,模型将生成重构的输出结果 x_recon。
-
- # 对重构的输出进行处理,并返回最终的结果。如果 x_recon 是一个元组类型且 return_ids 为假(即不返回标识符),则返回元组中的第一个元素;否则,直接返回 x_recon。
- if isinstance(x_recon, tuple) and not return_ids:
- return x_recon[0]
- else:
- return x_recon
LDM将损失函数反向传播来更新UNet模型的参数,AutoEncoderKL 和 FrozenCLIPEmbedder的参数在该反向传播中不会被更新。
从上述代码可以看出UNet的每个中间层都会拼接一次SpatialTransformer模块,该模块对应,使用 Attention 机制来更好的学习文本与图像的匹配关系。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。