当前位置:   article > 正文

天才程序员周弈帆 | Stable Diffusion 解读(三):原版实现源码解读(篇幅略长,建议收藏!)_stable diffusion代码

stable diffusion代码

本文来源公众号“天才程序员周弈帆”,仅用于学术分享,侵权删,干货满满。

原文链接:Stable Diffusion 解读(三):原版实现源码解读

天才程序员周弈帆 | Stable Diffusion 解读(一):回顾早期工作-CSDN博客

天才程序员周弈帆 | Stable Diffusion 解读(二):论文精读-CSDN博客

看完了Stable Diffusion的论文,在最后这几篇文章里,我们来学习Stable Diffusion的代码实现。具体来说,我们会学习Stable Diffusion官方仓库及Diffusers开源库中有关采样算法和U-Net的代码,而不会学习有关训练、VAE、text encoder (CLIP) 的代码。如今大多数工作都只会用到预训练的Stable Diffusion,只学采样算法和U-Net代码就能理解大多数工作了。

受字数限制,Diffusers的介绍会放到下一篇文章里。

建议读者在阅读本文之前了解DDPM、ResNet、U-Net、Transformer。

本文用到的Stable Diffusion版本是v1.5。Diffusers版本是0.25.0。为了提升可读性,本文对源代码做了一定的精简,部分不会运行到的分支会被略过。

1 算法梳理

在正式读代码之前,我们先用伪代码梳理一下Stable Diffusion的采样过程,并回顾一下U-Net架构的组成。实现Stable Diffusion的代码库有很多,各个库之间的API差异很大。但是,它们实际上都是在描述同一个算法,同一个模型。如果我们理解了算法和模型本身,就可以在学习时主动去找一个算法对应哪一段代码,而不是被动地去理解每一行代码在干什么。

1.1 LDM 采样算法

让我们从最早的DDPM开始,一步一步还原Latent Diffusion Model (LDM)的采样算法。DDPM的采样算法如下所示:

  1. def ddpm_sample(image_shape):
  2.   ddpm_scheduler = DDPMScheduler()
  3.   unet = UNet()
  4.   xt = randn(image_shape)
  5.   T = 1000
  6.   for t in T ... 1:
  7.     eps = unet(xt, t)
  8.     std = ddpm_scheduler.get_std(t)
  9.     xt = ddpm_scheduler.get_xt_prev(xt, t, eps, std)
  10.   return xt

在DDPM的实现中,一般会有一个类专门维护扩散模型的alpha,beta等变量。我们这里把这个类称为DDPMScheduler。此外,DDPM会用到一个U-Net神经网络unet,用于计算去噪过程中图像应该去除的噪声eps。准备好这两个变量后,就可以用randn()标准正态分布中采样一个纯噪声图像xt。它会被逐渐去噪,最终变成一幅图片。去噪过程中,时刻t会从总时刻T遍历至1(总时刻T一般取1000)。在每一轮去噪步骤中,U-Net会根据这一时刻的图像xt当前时间戳t估计出此刻应去除的噪声eps,根据xteps就能知道下一步图像的均值。除了均值,我们还要获取下一步图像的方差,这一般可以从DDPM调度类中直接获取。有了下一步图像的均值和方差,我们根据DDPM的公式,就能采样出下一步的图像。反复执行去噪循环,xt会从纯噪声图像变成一幅有意义的图像。

DDIM对DDPM的采样过程做了两点改进:1) 去噪的有效步数可以少于T步,由另一个变量ddim_steps决定;2) 采样的方差大小可以由eta决定。

因此,改进后的DDIM算法可以写成这样:

  1. def ddim_sample(image_shape, ddim_steps = 20, eta = 0):
  2. ddim_scheduler = DDIMScheduler()
  3. unet = UNet()
  4. xt = randn(image_shape)
  5. T = 1000
  6. timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
  7. for t in timesteps:
  8. eps = unet(xt, t)
  9. std = ddim_scheduler.get_std(t, eta)
  10. xt = ddim_scheduler.get_xt_prev(xt, t, eps, std)
  11. return xt
'
运行

其中,ddim_steps是去噪循环的执行次数。根据ddim_steps,DDIM调度器可以生成所有被使用到的t。比如对于T=1000, ddim_steps=20,被使用到的就只有[1000, 950, 900, ..., 50]这20个时间戳,其他时间戳就可以跳过不算了。eta会被用来计算方差,一般这个值都会设成0

DDIM是早期的加速扩散模型采样的算法。如今有许多比DDIM更好的采样方法,但它们多数都保留了stepseta这两个参数。因此,在使用所有采样方法时,我们可以不用关心实现细节,只关注多出来的这两个参数。

在DDIM的基础上,LDM从生成像素空间上的图像变为生成隐空间上的图像。隐空间图像需要再做一次解码才能变回真实图像。从代码上来看,使用LDM后,只需要多准备一个VAE,并对最后的隐空间图像zt解码。

  1. def ldm_ddim_sample(image_shape, ddim_steps = 20, eta = 0):
  2.   ddim_scheduler = DDIMScheduler()
  3.   vae = VAE()
  4.   unet = UNet()
  5.   zt = randn(image_shape)
  6.   T = 1000
  7.   timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
  8.   for t in timesteps:
  9.     eps = unet(zt, t)
  10.     std = ddim_scheduler.get_std(t, eta)
  11.     zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
  12.   xt = vae.decoder.decode(zt)
  13.   return xt

而想用LDM实现文生图,则需要给一个额外的文本输入text。文本编码器会把文本编码成张量c,输入进unet。其他地方的实现都和之前的LDM一样。

  1. def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0):
  2.   ddim_scheduler = DDIMScheduler()
  3.   vae = VAE()
  4.   unet = UNet()
  5.   zt = randn(image_shape)
  6.   T = 1000
  7.   timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
  8.   text_encoder = CLIP()
  9.   c = text_encoder.encode(text)
  10.   for t = timesteps:
  11.     eps = unet(zt, t, c)
  12.     std = ddim_scheduler.get_std(t, eta)
  13.     zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
  14.   xt = vae.decoder.decode(zt)
  15.   return xt

最后这个能实现文生图的LDM就是我们熟悉的Stable Diffusion。Stable Diffusion的采样算法看上去比较复杂,但如果能够从DDPM开始把各个功能都拆开来看,理解起来就不是那么困难了。

1.2 U-Net 结构组成

Stable Diffusion代码实现中的另一个重点是去噪网络U-Net的实现。仿照上一节的学习方法,我们来逐步学习Stable Diffusion中的U-Net是怎么从最经典的纯卷积U-Net逐渐发展而来的。

最早的U-Net的结构如下图所示:

可以看出,U-Net的结构有以下特点

  • 整体上看,U-Net由若干个大层组成。特征在每一大层会被下采样成尺寸更小的特征,再被上采样回原尺寸的特征。整个网络构成一个U形结构

  • 下采样后,特征的通道数会变多。一般情况下,每次下采样后图像尺寸减半,通道数翻倍。上采样过程则反之。

  • 为了防止信息在下采样的过程中丢失,U-Net每一大层在下采样前的输出会作为额外输入拼接到每一大层上采样前的输入上。这种数据连接方式类似于ResNet中的「短路连接」

DDPM则使用了一种改进版的U-Net。改进主要有两点:

  • 原来的卷积层被替换成了ResNet中的残差卷积模块。每一大层有若干个这样的子模块。对于较深的大层,残差卷积模块后面还会接一个自注意力模块。

  • 原来模型每一大层只有一个短路连接。现在每个大层下采样部分的每个子模块的输出都会额外输入到其对称的上采样部分的子模块上。直观上来看,就是短路连接更多了一点,输入信息更不容易在下采样过程中丢失。

最后,LDM提出了一种给U-Net添加额外约束信息的方法:把U-Net中的自注意力模块换成交叉注意力模块。具体来说,DDPM的U-Net的自注意力模块被换成了标准的Transformer模块。约束信息可以作为Cross Attention的K, V输入进模块中。

Stable Diffusion的U-Net还在结构上有少许修改,该U-Net的每一大层都有Transformer块,而不是只有较深的大层有。

至此,我们已经学完了Stable Diffusion的采样原理和U-Net结构。接下来我们来看一看它们在不同框架下的代码实现。

2 Stable Diffusion 官方 GitHub 仓库

2.1 安装

克隆仓库后,照着官方Markdown文档安装即可。

git clone git@github.com:CompVis/stable-diffusion.git

先用下面的命令创建conda环境,此后ldm环境就是运行Stable Diffusiion的conda环境。

  1. conda env create -f environment.yaml
  2. conda activate ldm

之后去网上下一个Stable Diffusion的模型文件。比较常见一个版本是v1.5,该模型在Hugging Face上:https://huggingface.co/runwayml/stable-diffusion-v1-5 (推荐下载v1-5-pruned.ckpt)。下载完毕后,把模型软链接到指定位置。

  1. mkdir -p models/ldm/stable-diffusion-v1/
  2. ln -s <path/to/model.ckpt> models/ldm/stable-diffusion-v1/model.ckpt 

准备完毕后,只要输入下面的命令,就可以生成实现文生图了。

python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" 

在默认的参数下,“一幅骑着马的飞行员的照片”的绘制结果会被保存在outputs/txt2img-samples中。你也可以通过--outdir <dir>参数来指定输出到的文件夹。我得到的一些绘制结果为:

【说明】如果你在安装时碰到了错误,可以在搜索引擎上或者GitHub的issue里搜索,一般都能搜到其他人遇到的相同错误。

2.2 主函数

接下来,我们来探究一下scripts/txt2img.py的执行过程。为了方便阅读,我们可以简化代码中的命令行处理,得到下面这份精简代码。(你可以把这份代码复制到仓库根目录下的一个新Python脚本里并直接运行。别忘了修改代码中的模型路径

  1. import os
  2. import torch
  3. import numpy as np
  4. from omegaconf import OmegaConf
  5. from PIL import Image
  6. from tqdm import tqdm, trange
  7. from einops import rearrange
  8. from pytorch_lightning import seed_everything
  9. from torch import autocast
  10. from torchvision.utils import make_grid
  11. from ldm.util import instantiate_from_config
  12. from ldm.models.diffusion.ddim import DDIMSampler
  13. def load_model_from_config(config, ckpt, verbose=False):
  14.     print(f"Loading model from {ckpt}")
  15.     pl_sd = torch.load(ckpt, map_location="cpu")
  16.     if "global_step" in pl_sd:
  17.         print(f"Global Step: {pl_sd['global_step']}")
  18.     sd = pl_sd["state_dict"]
  19.     model = instantiate_from_config(config.model)
  20.     m, u = model.load_state_dict(sd, strict=False)
  21.     if len(m) > 0 and verbose:
  22.         print("missing keys:")
  23.         print(m)
  24.     if len(u) > 0 and verbose:
  25.         print("unexpected keys:")
  26.         print(u)
  27.     model.cuda()
  28.     model.eval()
  29.     return model
  30. def main():
  31.     seed = 42
  32.     config = 'configs/stable-diffusion/v1-inference.yaml'
  33.     ckpt = 'ckpt/v1-5-pruned.ckpt'
  34.     outdir = 'tmp'
  35.     n_samples = batch_size = 3
  36.     n_rows = batch_size
  37.     n_iter = 2
  38.     prompt = 'a photograph of an astronaut riding a horse'
  39.     data = [batch_size * [prompt]]
  40.     scale = 7.5
  41.     C = 4
  42.     f = 8
  43.     H = W = 512
  44.     ddim_steps = 50
  45.     ddim_eta = 0.0
  46.     seed_everything(seed)
  47.     config = OmegaConf.load(config)
  48.     model = load_model_from_config(config, ckpt)
  49.     device = torch.device(
  50.         "cuda"if torch.cuda.is_available() else torch.device("cpu")
  51.     model = model.to(device)
  52.     sampler = DDIMSampler(model)
  53.     os.makedirs(outdir, exist_ok=True)
  54.     outpath = outdir
  55.     sample_path = os.path.join(outpath, "samples")
  56.     os.makedirs(sample_path, exist_ok=True)
  57.     grid_count = len(os.listdir(outpath)) - 1
  58.     start_code = None
  59.     precision_scope = autocast
  60.     with torch.no_grad():
  61.         with precision_scope("cuda"):
  62.             with model.ema_scope():
  63.                 all_samples = list()
  64.                 for n in trange(n_iter, desc="Sampling"):
  65.                     for prompts in tqdm(data, desc="data"):
  66.                         uc = None
  67.                         if scale != 1.0:
  68.                             uc = model.get_learned_conditioning(
  69.                                 batch_size * [""])
  70.                         if isinstance(prompts, tuple):
  71.                             prompts = list(prompts)
  72.                         c = model.get_learned_conditioning(prompts)
  73.                         shape = [C, H // f, W // f]
  74.                         samples_ddim, _ = sampler.sample(S=ddim_steps,
  75.                                                          conditioning=c,
  76.                                                          batch_size=n_samples,
  77.                                                          shape=shape,
  78.                                                          verbose=False,
  79.                                                          unconditional_guidance_scale=scale,
  80.                                                          unconditional_conditioning=uc,
  81.                                                          eta=ddim_eta,
  82.                                                          x_T=start_code)
  83.                         x_samples_ddim = model.decode_first_stage(samples_ddim)
  84.                         x_samples_ddim = torch.clamp(
  85.                             (x_samples_ddim + 1.0) / 2.0min=0.0max=1.0)
  86.                         all_samples.append(x_samples_ddim)
  87.                 grid = torch.stack(all_samples, 0)
  88.                 grid = rearrange(grid, 'n b c h w -> (n b) c h w')
  89.                 grid = make_grid(grid, nrow=n_rows)
  90.                 # to image
  91.                 grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
  92.                 img = Image.fromarray(grid.astype(np.uint8))
  93.                 img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
  94.                 grid_count += 1
  95.     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
  96.           f" \nEnjoy.")
  97. if __name__ == "__main__":
  98.     main()

抛开前面一大堆初始化操作,代码的核心部分只有下面几行。

  1. uc = None
  2. if scale != 1.0:
  3.     uc = model.get_learned_conditioning(
  4.         batch_size * [""])
  5. if isinstance(prompts, tuple):
  6.     prompts = list(prompts)
  7. c = model.get_learned_conditioning(prompts)
  8. shape = [C, H // f, W // f]
  9. samples_ddim, _ = sampler.sample(S=ddim_steps,
  10.                                   conditioning=c,
  11.                                   batch_size=n_samples,
  12.                                   shape=shape,
  13.                                   verbose=False,
  14.                                   unconditional_guidance_scale=scale,
  15.                                   unconditional_conditioning=uc,
  16.                                   eta=ddim_eta,
  17.                                   x_T=start_code)
  18. x_samples_ddim = model.decode_first_stage(samples_ddim)

我们来逐行分析一下这段代码。一开始的几行是执行Classifier-Free Guidance (CFG)uc表示的是CFG中的无约束下的约束张量。scale表示的是执行CFG的程度,scale不等于1.0即表示启用CFG。model.get_learned_conditioning表示用CLIP把文本编码成张量。对于文本约束的模型,无约束其实就是输入文本为空字符串("")。因此,在代码中,若启用了CFG,则会用CLIP编码空字符串,编码结果为uc

如果你没学过CFG,也不用担心。你可以暂时不要去理解上面这段话。等读完了后文中有关CFG的代码后,你差不多就能理解CFG的用法了。

  1. uc = None
  2. if scale != 1.0:
  3.     uc = model.get_learned_conditioning(
  4.         batch_size * [""])

之后的几行是在把用户输入的文本编码成张量。同样,model.get_learned_conditioning表示用CLIP把输入文本编码成张量c

  1. if isinstance(prompts, tuple):
  2.     prompts = list(prompts)
  3. c = model.get_learned_conditioning(prompts)

接着是用扩散模型的采样器生成图片。在这份代码中,sampler是DDIM采样器,sampler.sample函数直接完成了图像生成。

  1. shape = [C, H // f, W // f]
  2. samples_ddim, _ = sampler.sample(S=ddim_steps,
  3.                                   conditioning=c,
  4.                                   batch_size=n_samples,
  5.                                   shape=shape,
  6.                                   verbose=False,
  7.                                   unconditional_guidance_scale=scale,
  8.                                   unconditional_conditioning=uc,
  9.                                   eta=ddim_eta,
  10.                                   x_T=start_code)

后,LDM生成的隐空间图片被VAE解码成真实图片。函数model.decode_first_stage负责图片解码。x_samples_ddim在后续的代码中会被后处理成正确格式的RGB图片,并输出至文件里。

x_samples_ddim = model.decode_first_stage(samples_ddim)

Stable Diffusion 官方实现的主函数主要就做了这些事情。这份实现还是有一些凌乱的。采样算法的一部分内容被扔到了主函数里,另一部分放到了DDIM采样器里。在阅读官方实现的源码时,既要去读主函数里的内容,也要去读采样器里的内容。

接下来,我们来看一看DDIM采样器的部分代码,学完采样算法的剩余部分的实现。

2.3 DDIM 采样器

回头看主函数的前半部分,DDIM采样器是在下面的代码里导入的:

from ldm.models.diffusion.ddim import DDIMSampler

跳转到ldm/models/diffusion/ddim.py文件,我们可以找到DDIMSampler的实现。

先看一下这个类的构造函数。构造函数主要是把U-Net model给存了下来。后文中的self.model都指的是U-Net。

  1. def __init__(self, model, schedule="linear", **kwargs):
  2.     super().__init__()
  3.     self.model = model
  4.     self.ddpm_num_timesteps = model.num_timesteps
  5.     self.schedule = schedule
  6. # in main
  7. config = OmegaConf.load(config)
  8. model = load_model_from_config(config, ckpt)
  9. model = model.to(device)
  10. sampler = DDIMSampler(model)

再沿着类的self.sample方法,看一下DDIM采样的实现代码。以下是self.sample方法的主要内容。这个方法其实就执行了一个self.make_schedule,之后把所有参数原封不动地传到了self.ddim_sampling里。

  1. @torch.no_grad()
  2. def sample(self,
  3.             S,
  4.             batch_size,
  5.             shape,
  6.             conditioning=None,
  7.             ...
  8.             ):
  9.     if conditioning is not None:
  10.         ...
  11.     self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
  12.     # sampling
  13.     C, H, W = shape
  14.     size = (batch_size, C, H, W)
  15.     print(f'Data shape for DDIM sampling is {size}, eta {eta}')
  16.     samples, intermediates = self.ddim_sampling(...)

self.make_schedule用于预处理扩散模型的中间计算参数。它的大部分实现细节可以略过。DDIM用到的有效时间戳列表就是在这个函数里设置的,该列表通过make_ddim_timesteps获取,并保存在self.ddim_timesteps中。此外,由ddim_eta决定的扩散模型的方差也是在这个方法里设置的。大致扫完这个方法后,我们可以直接跳到self.ddim_sampling的代码。

  1. def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
  2.     self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
  3.                                               num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
  4.     ...

穿越重重的嵌套,我们总算能看到DDIM采样的实现方法self.ddim_sampling了。它的主要内容如下所示:

  1. @torch.no_grad()
  2. def ddim_sampling(self, ...):
  3.     device = self.model.betas.device
  4.     b = shape[0]
  5.     img = torch.randn(shape, device=device)
  6.     timesteps = self.ddim_timesteps
  7.     intermediates = ...
  8.     time_range = np.flip(timesteps)
  9.     total_steps = timesteps.shape[0]
  10.     iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
  11.     for i, step in enumerate(iterator):
  12.         index = total_steps - i - 1
  13.         ts = torch.full((b,), step, device=device, dtype=torch.long)
  14.         outs = self.p_sample_ddim(img, cond, ts, ...)
  15.         img, pred_x0 = outs
  16.     return img, intermediates

这段代码和我们之前自己写的伪代码非常相似。一开始,方法获取了在make_schedule里初始化的DDIM有效时间戳列表self.ddim_timesteps,并预处理成一个iterator。该迭代器用于控制DDIM去噪循环。每一轮循环会根据当前时刻的图像img和时间戳ts计算下一步的图像img。具体来说,代码每次用当前的时间戳step创建一个内容全部为step,形状为(b,)的张量ts。该张量会和当前的隐空间图像img,约束信息张量cond一起传给执行一轮DDIM去噪的p_sample_ddim方法。p_sample_ddim方法会返回下一步的图像img。最后,经过多次去噪后,ddim_sampling方法将去噪后的隐空间图像img返回。

p_sample_ddim里的p_sample看上去似乎意义不明,实际上这个叫法来自于DDPM论文。在DDPM论文中,扩散模型的前向过程用字母q表示,反向过程用字母p表示。因此,反向过程的一轮去噪在代码里被叫做p_sample

最后来看一下p_sample_ddim这个方法,它的主体部分如下:

  1. @torch.no_grad()
  2. def p_sample_ddim(self, x, c, t, ...):
  3.     b, *_, device = *x.shape, x.device
  4.     if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
  5.         e_t = self.model.apply_model(x, t, c)
  6.     else:
  7.         x_in = torch.cat([x] * 2)
  8.         t_in = torch.cat([t] * 2)
  9.         c_in = torch.cat([unconditional_conditioning, c])
  10.         e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
  11.         e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
  12.     # Prepare variables
  13.     ...
  14.     # current prediction for x_0
  15.     pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
  16.     if quantize_denoised:
  17.         pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
  18.     # direction pointing to x_t
  19.     dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
  20.     noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
  21.     if noise_dropout > 0.:
  22.         noise = torch.nn.functional.dropout(noise, p=noise_dropout)
  23.     x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
  24.     return x_prev, pred_x0

方法的内容大致可以拆成三段:首先,方法调用U-Net self.model,使用CFG来计算除这一轮该去掉的噪声e_t。然后,方法预处理出DDIM的中间变量。最后,方法根据DDIM的公式,计算出这一轮去噪后的图片x_prev。我们着重看第一部分的代码。

不启用CFG时,方法直接通过self.model.apply_model(x, t, c)调用U-Net,算出这一轮的噪声e_t。而想启用CFG,需要输入空字符串的约束张量unconditional_conditioning,且CFG的强度unconditional_guidance_scale不为1。CFG的执行过程是:对U-Net输入不同的约束c,先用空字符串约束得到一个预测噪声e_t_uncond,再用输入的文本约束得到一个预测噪声e_t。之后令e_t = et_uncond + scale * (e_t - e_t_uncond)scale大于1,即表明我们希望预测噪声更加靠近有输入文本的那一个。直观上来看,scale越大,最后生成的图片越符合输入文本,越偏离空文本。下面这段代码正是实现了上述这段逻辑,只不过代码使用了一些数据拼接技巧,让空字符串约束下和输入文本约束下的结果在一次U-Net推理中获得

  1. if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
  2.     e_t = self.model.apply_model(x, t, c)
  3. else:
  4.     x_in = torch.cat([x] * 2)
  5.     t_in = torch.cat([t] * 2)
  6.     c_in = torch.cat([unconditional_conditioning, c])
  7.     e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
  8.     e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

p_sample_ddim 方法的后续代码都是在实现下面这个DDIM采样公式。代码工工整整地计算了公式中的predicted_x0dir_xtnoise,非常易懂,没有需要特别注意的地方。

我们已经看完了p_sample_ddim的代码。该方法可以实现一步去噪操作。多次调用该方法去噪后,我们就能得到生成的隐空间图片。该图片会被返回到main函数里,被VAE的解码器解码成普通图片。至此,我们就学完了Stable Diffusion官方仓库的采样代码。

对照下面这份我们之前写的伪代码,我们再来梳理一下Stable Diffusion官方仓库的代码逻辑。官方仓库的采样代码一部分在main函数里,另一部分在ldm/models/diffusion/ddim.py里。main函数主要完成了编码约束文字、解码隐空间图像这两件事。剩下的DDIM采样以及各种Diffusion图像编辑功能都是在ldm/models/diffusion/ddim.py文件中实现的。

  1. def ldm_text_to_image(image_shape, text, ddim_steps = 20, eta = 0)
  2.   ddim_scheduler = DDIMScheduler()
  3.   vae = VAE()
  4.   unet = UNet()
  5.   zt = randn(image_shape)
  6.   eta = input()
  7.   T = 1000
  8.   timesteps = ddim_scheduler.get_timesteps(T, ddim_steps) # [1000, 950, 900, ...]
  9.   text_encoder = CLIP()
  10.   c = text_encoder.encode(text)
  11.   for t = timesteps:
  12.     eps = unet(zt, t, c)
  13.     std = ddim_scheduler.get_std(t, eta)
  14.     zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
  15.   xt = vae.decoder.decode(zt)
  16.   return xt

在学习代码时,要着重学习DDIM采样器部分的代码。大部分基于Diffusion的图像编辑技术都是在DDIM采样的中间步骤中做文章,只要学懂了DDIM采样的代码,学相关图像编辑技术就会非常轻松。除此之外,和LDM相关的文字约束编码、隐空间图像编码解码的接口函数也需要熟悉,不少技术会调用到这几项功能。

还有一些Diffusion相关工作会涉及U-Net的修改。接下来,我们就来看Stable Diffusion官方仓库中U-Net的实现。

2.4 U-Net

我们来回头看一下main函数和DDIM采样中U-Net的调用逻辑。和U-Net有关的代码如下所示。LDM模型类 model在主函数中通过load_model_from_config从配置文件里创建,随后成为了sampler的成员变量。在DDIM去噪循环中,LDM模型里的U-Net会在self.model.apply_model方法里被调用。

  1. # main.py
  2. config = 'configs/stable-diffusion/v1-inference.yaml'
  3. config = OmegaConf.load(config)
  4. model = load_model_from_config(config, ckpt)
  5. sampler = DDIMSampler(model)
  6. # ldm/models/diffusion/ddim.py
  7. e_t = self.model.apply_model(x, t, c)

为了知道U-Net是在哪个类里定义的,我们需要打开配置文件 configs/stable-diffusion/v1-inference.yaml。该配置文件有这样一段话:

  1. model:
  2.   target: ldm.models.diffusion.ddpm.LatentDiffusion
  3.   params:
  4.     conditioning_key: crossattn
  5.     unet_config:
  6.         target: ldm.modules.diffusionmodules.openaimodel.UNetModel

根据这段话,我们知道LDM类定义在ldm/models/diffusion/ddpm.pyLatentDiffusion里,U-Net类定义在ldm/modules/diffusionmodules/openaimodel.pyUNetModel里。一个LDM类有一个U-Net类的实例。我们先简单看一看LatentDiffusion的实现。

ldm/models/diffusion/ddpm.py原本来自DDPM论文的官方仓库,内含DDPM类的实现。DDPM类维护了扩散模型公式里的一些变量,同时维护了U-Net类的实例。LDM的作者基于之前DDPM的代码进行开发,定义了一个继承自DDPMLatentDiffusion类。除了DDPM本身的功能外,LatentDiffusion还维护了VAE(self.first_stage_model),CLIP(self.cond_stage_model)。也就是说,LatentDiffusion主要维护了扩散模型中间变量、U-Net、VAE、CLIP这四类信息。这样,所有带参数的模型都在LatentDiffusion里,我们可以从一个checkpoint文件中读取所有的模型的参数。相关代码定义代码如下:

把所有模型定义在一起有好处也有坏处。好处在于,用户想使用Stable Diffusion时,只需要下载一个checkpoint文件就行了。坏处在于,哪怕用户只改了某个子模型(如U-Net),为了保存整个模型,他还是得把其他子模型一起存下来。这其中存在着信息冗余,十分不灵活。Diffusers框架没有把模型全存在一个文件里,而是放到了一个文件夹里。

  1. class DDPM(pl.LightningModule):
  2.     # classic DDPM with Gaussian diffusion, in image space
  3.     def __init__(self,
  4.                  unet_config,
  5.                  ...):
  6.         self.model = DiffusionWrapper(unet_config, conditioning_key)
  7.         
  8. class LatentDiffusion(DDPM):
  9.     """main class"""
  10.     def __init__(self,
  11.                  first_stage_config,
  12.                  cond_stage_config,
  13.                  ...):
  14.         self.instantiate_first_stage(first_stage_config)
  15.         self.instantiate_cond_stage(cond_stage_config)

我们主要关注LatentDiffusion类的apply_model方法,它用于调用U-Net self.modelapply_model看上去有很长,但略过了我们用不到的一些代码后,整个方法其实非常短。一开始,方法对输入的约束信息编码cond做了一个前处理,判断约束是哪种类型。如论文里所描述的,LDM支持两种约束:将约束与输入拼接、将约束注入到交叉注意力层中。方法会根据self.model.conditioning_keyconcat还是crossattn,使用不同的约束方式。Stable Diffusion使用的是后者,即self.model.conditioning_key == crossattn。做完前处理后,方法执行了x_recon = self.model(x_noisy, t, **cond)。接下来的处理交给U-Net self.model来完成。

  1. def apply_model(self, x_noisy, t, cond, return_ids=False):
  2.     if isinstance(cond, dict):
  3.         # hybrid case, cond is exptected to be a dict
  4.         pass
  5.     else:
  6.         if not isinstance(cond, list):
  7.             cond = [cond]
  8.         key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
  9.         cond = {key: cond}
  10.     x_recon = self.model(x_noisy, t, **cond)
  11.     if isinstance(x_recon, tupleand not return_ids:
  12.         return x_recon[0]
  13.     else:
  14.         return x_recon

现在,我们跳转到ldm/modules/diffusionmodules/openaimodel.pyUNetModel里。UNetModel只定义了神经网络层的运算,没有多余的功能。我们只需要看它的__init__方法和forward方法。我们先来看较为简短的forward方法。

  1. def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
  2.     hs = []
  3.     t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
  4.     emb = self.time_embed(t_emb)
  5.     h = x.type(self.dtype)
  6.     for module in self.input_blocks:
  7.         h = module(h, emb, context)
  8.         hs.append(h)
  9.     h = self.middle_block(h, emb, context)
  10.     for module in self.output_blocks:
  11.         h = th.cat([h, hs.pop()], dim=1)
  12.         h = module(h, emb, context)
  13.     h = h.type(x.dtype)
  14.     return self.out(h)

forward方法的输入是x, timesteps, context,分别表示当前去噪时刻的图片、当前时间戳、文本约束编码。根据这些输入,forward会输出当前时刻应去除的噪声eps。一开始,方法会先对timesteps使用Transformer论文中介绍的位置编码timestep_embedding,得到时间戳的编码t_embt_emb再经过几个线性层,得到最终的时间戳编码emb。而context已经是CLIP处理过的编码,它不需要做额外的预处理。时间戳编码emb和文本约束编码context随后会注入到U-Net的所有中间模块中

  1. def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
  2.     hs = []
  3.     t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
  4.     emb = self.time_embed(t_emb)

经过预处理后,方法开始处理U-Net的计算。中间结果h会经过U-Net的下采样模块input_blocks,每一个子模块的临时输出都会被保存进一个栈hs里。

  1.  h = x.type(self.dtype)
  2. for module in self.input_blocks:
  3.     h = module(h, emb, context)
  4.     hs.append(h)

接着,h会经过U-Net的中间模块。

h = self.middle_block(h, emb, context)

随后,h开始经过U-Net的上采样模块output_blocks。此时每一个编码器子模块的临时输出会从栈hs里弹出,作为对应解码器子模块的额外输入。额外输入hs.pop()会与中间结果h拼接到一起输入进子模块里。

  1. for module in self.output_blocks:
  2.     h = th.cat([h, hs.pop()], dim=1)
  3.     h = module(h, emb, context)
  4. h = h.type(x.dtype)

最后,h会被输出层转换成一个通道数正确的eps张量。

return self.out(h)

这段代码的数据连接图如下所示:

在阅读__init__前,我们先看一下待会会用到的另一个模块类TimestepEmbedSequential的定义。在PyTorch中,一系列输入和输出都只有一个变量的模块在串行连接时,可以用串行模块类nn.Sequential来把多个模块合并简化成一个模块。而在扩散模型中,多数模块的输入是x, t, c三个变量,输出是一个变量。为了也能用类似的串行模块类把扩散模型的模块合并在一起,代码中包含了一个TimestepEmbedSequential类。它的行为类似于nn.Sequential,只不过它支持x, t, c的输入。forward中用到的多数模块都是通过TimestepEmbedSequential创建的

  1. class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
  2.     def forward(self, x, emb, context=None):
  3.         for layer in self:
  4.             if isinstance(layer, TimestepBlock):
  5.                 x = layer(x, emb)
  6.             elif isinstance(layer, SpatialTransformer):
  7.                 x = layer(x, context)
  8.             else:
  9.                 x = layer(x)
  10.         return x

看完了数据的计算过程,我们回头来看各个子模块在__init__方法中是怎么被详细定义的。__init__的主要内容如下:

  1. class UNetModel(nn.Module):
  2.     def __init__(self, ...):
  3.         self.time_embed = nn.Sequential(
  4.             linear(model_channels, time_embed_dim),
  5.             nn.SiLU(),
  6.             linear(time_embed_dim, time_embed_dim),
  7.         )
  8.         self.input_blocks = nn.ModuleList(
  9.             [
  10.                 TimestepEmbedSequential(
  11.                     conv_nd(dims, in_channels, model_channels, 3, padding=1)
  12.                 )
  13.             ]
  14.         )
  15.         for level, mult in enumerate(channel_mult):
  16.             for _ in range(num_res_blocks):
  17.                 layers = [
  18.                     ResBlock(...)]
  19.                 ch = mult * model_channels
  20.                 if ds in attention_resolutions:
  21.                      layers.append(
  22.                         AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...))
  23.                 self.input_blocks.append(TimestepEmbedSequential(*layers))
  24.             if level != len(channel_mult) - 1:
  25.                 out_ch = ch
  26.                 self.input_blocks.append(
  27.                     TimestepEmbedSequential(
  28.                         ResBlock(...)
  29.                         if resblock_updown
  30.                         else Downsample(...)
  31.                     )
  32.                 )
  33.         self.middle_block = TimestepEmbedSequential(
  34.             ResBlock(...),
  35.             AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...),
  36.             ResBlock(...),
  37.         )
  38.         self.output_blocks = nn.ModuleList([])
  39.         for level, mult in list(enumerate(channel_mult))[::-1]:
  40.             for i in range(num_res_blocks + 1):
  41.                 ich = input_block_chans.pop()
  42.                 layers = [
  43.                     ResBlock(...)
  44.                 ]
  45.                 ch = model_channels * mult
  46.                 if ds in attention_resolutions:
  47.                     layers.append(
  48.                         AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)
  49.                     )
  50.                 if level and i == num_res_blocks:
  51.                     out_ch = ch
  52.                     layers.append(
  53.                         ResBlock(...)
  54.                         if resblock_updown
  55.                         else Upsample(...)
  56.                     )
  57.                     ds //= 2
  58.                 self.output_blocks.append(TimestepEmbedSequential(*layers))
  59.     self.out = nn.Sequential(
  60.             normalization(ch),
  61.             nn.SiLU(),
  62.             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
  63.         )

__init__方法的代码很长。在阅读这样的代码时,我们不需要每一行都去细读,只需要理解代码能拆成几块,每一块在做什么即可__init__方法其实就是定义了forward中用到的5个模块,我们一个一个看过去即可。

  1. class UNetModel(nn.Module):
  2.     def __init__(self, ...):
  3.         self.time_embed = ...
  4.         self.input_blocks = nn.ModuleList(...)
  5.         for level, mult in enumerate(channel_mult):
  6.             ...
  7.         self.middle_block = ...
  8.         self.output_blocks = nn.ModuleList([])
  9.         for level, mult in list(enumerate(channel_mult))[::-1]:
  10.             ...
  11.     self.out = ...

先来看time_embed。回忆一下,在forward里,输入的整数时间戳会被正弦编码timestep_embedding(即Transformer中的位置编码)编码成一个张量。之后,时间戳编码处理模块time_embed用于进一步提取时间戳编码的特征。从下面的代码中可知,它本质上就是一个由两个普通线性层构成的模块。

  1. self.time_embed = nn.Sequential(
  2.             linear(model_channels, time_embed_dim),
  3.             nn.SiLU(),
  4.             linear(time_embed_dim, time_embed_dim),
  5.         )

再来看U-Net最后面的输出模块out。输出模块的结构也很简单,它主要包含了一个卷积层,用于把中间变量的通道数从dims变成model_channels

  1. self.out = nn.Sequential(
  2.             normalization(ch),
  3.             nn.SiLU(),
  4.             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
  5.         )

接下来,我们把目光聚焦在U-Net的三个核心模块上:input_blocksmiddle_blockoutput_blocks。这三个模块的组成都很类似,都用到了残差块ResBlock和注意力块。稍有不同的是,input_blocks的每一大层后面都有一个下采样模块,output_blocks的每一大层后面都有一个上采样模块。上下采样模块的结构都很常规,与经典的U-Net无异。我们把学习的重点放在残差块和注意力块上。我们先看这两个模块的内部实现细节,再来看它们是怎么拼接起来的。

Stable Diffusion的U-Net中的ResBlock和原DDPM的U-Net的ResBlock功能完全一样,都是在普通残差块的基础上,支持时间戳编码的额外输入。具体来说,普通的残差块是由两个卷积模块和一条短路连接构成的,即y = x + conv(conv(x))。如果经过两个卷积块后数据的通道数发生了变化,则要在短路连接上加一个转换通道数的卷积,即y = conv(x) + conv(conv(x))

在这种普通残差块的基础上,扩散模型中的残差块还支持时间戳编码t的输入。为了把t和输入x的信息融合在一起,t会和经过第一个卷积后的中间结果conv(x)加在一起。可是,t的通道数和conv(x)的通道数很可能会不一样。通道数不一样的数据是不能直接加起来的。为此,每一个残差块中都有一个用于转换t通道数的线性层。这样,tconv(x)就能相加了。整个模块的计算可以表示成y=conv(x) + conv(conv(x) + linear(t))。残差块的示意图和源代码如下:

代码解析:

  1. class ResBlock(TimestepBlock):
  2.     def __init__(self, ...):
  3.         super().__init__()
  4.         ...
  5.         self.in_layers = nn.Sequential(
  6.             normalization(channels),
  7.             nn.SiLU(),
  8.             conv_nd(dims, channels, self.out_channels, 3, padding=1),
  9.         )
  10.         self.emb_layers = nn.Sequential(
  11.             nn.SiLU(),
  12.             linear(
  13.                 emb_channels,
  14.                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
  15.             ),
  16.         )
  17.         self.out_layers = nn.Sequential(
  18.             normalization(self.out_channels),
  19.             nn.SiLU(),
  20.             nn.Dropout(p=dropout),
  21.             zero_module(
  22.                 conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
  23.             ),
  24.         )
  25.         if self.out_channels == channels:
  26.             self.skip_connection = nn.Identity()
  27.         elif use_conv:
  28.             self.skip_connection = conv_nd(
  29.                 dims, channels, self.out_channels, 3, padding=1
  30.             )
  31.         else:
  32.             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
  33.     def forward(self, x, emb):
  34.         h = self.in_layers(x)
  35.         emb_out = self.emb_layers(emb).type(h.dtype)
  36.         while len(emb_out.shape) < len(h.shape):
  37.             emb_out = emb_out[..., None]
  38.         h = h + emb_out
  39.         h = self.out_layers(h)
  40.         return self.skip_connection(x) + h

代码中的in_layers是第一个卷积模块,out_layers是第二个卷积模块。skip_connection是用于调整短路连接通道数的模块。若输入输出的通道数相同,则该模块是一个恒等函数,不对数据做任何修改。emb_layers是调整时间戳编码通道数的线性层模块。这些模块的定义都在ResBlock__init__里。它们的结构都很常规,没有值得注意的地方。我们可以着重阅读模型的forward方法。

如前文所述,在forward中,输入x会先经过第一个卷积模块in_layers,再与经过了emb_layers调整的时间戳编码emb相加后,输入进第二个卷积模块out_layers。最后,做完计算的数据会和经过了短路连接的原输入skip_connection(x)加在一起,作为整个残差块的输出。

  1. def forward(self, x, emb):
  2.     h = self.in_layers(x)
  3.     emb_out = self.emb_layers(emb).type(h.dtype)
  4.     while len(emb_out.shape) < len(h.shape):
  5.         emb_out = emb_out[..., None]
  6.     h = h + emb_out
  7.     h = self.out_layers(h)
  8.     return self.skip_connection(x) + h

这里有一点实现细节需要注意。时间戳编码emb_out的形状是[n, c]。为了把它和形状为[n, c, h, w]的图片加在一起,需要把它的形状变成[n, c, 1, 1]后再相加(形状为[n, c, 1, 1]的数据在与形状为[n, c, h, w]的数据做加法时形状会被自动广播成[n, c, h, w])。在PyTorch中,x=x[..., None]可以在一个数据最后加一个长度为1的维度。比如对于形状为[n, c]tt[..., None]的形状就会是[n, c, 1]

残差块的内容到此结束。

我们接着来看注意力模块。在看模块的具体实现之前,我们先看一下源代码中有哪几种注意力模块。在U-Net的代码中,注意力模型是用以下代码创建的:

  1. if ds in attention_resolutions:
  2.     layers.append(
  3.         AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)
  4.     )

第一行if ds in attention_resolutions:用于控制在U-Net的哪几个大层。Stable Diffusion每一大层都用了注意力模块,可以忽略这一行。随后,代码根据是否设置use_spatial_transformer来创建AttentionBlock或是SpatialTransformerAttentionBlock是DDPM中采样的普通自注意力模块,而SpatialTransformer是LDM中提出的支持额外约束的标准Transfomer块。Stable Diffusion使用的是SpatialTransformer。我们就来看一看这个模块的实现细节。

如前所述,SpatialTransformer使用的是标准的Transformer块,它和Transformer中的Transformer块完全一致。输入x先经过一个自注意力层,再过一个交叉注意力层。在此期间,约束编码c会作为交叉注意力层的K, V输入进模块。最后,数据经过一个全连接层。每一层的输入都会和输出做一个残差连接。

当然,标准Transformer是针对一维序列数据的。要把Transformer用到图像上,则需要把图像的宽高拼接到同一维,即对张量做形状变换n c h w -> n c (h * w)。做完这个变换后,就可以把数据直接输入进Transformer模块了。 这些图像数据与序列数据的适配都是在SpatialTransformer里完成的。SpatialTransformer类并没有直接实现Transformer块的细节,仅仅是U-Net和Transformer块之间的一个过渡。Transformer块的实现在它的一个子模块里。我们来看它的实现代码。

SpatialTransformer有两个卷积层proj_inproj_out,负责图像通道数与Transformer模块通道数之间的转换。SpatialTransformertransformer_blocks才是真正的Transformer模块。

  1. class SpatialTransformer(nn.Module):
  2.     def __init__(self, in_channels, n_heads, d_head,
  3.                  depth=1, dropout=0., context_dim=None):
  4.         super().__init__()
  5.         self.in_channels = in_channels
  6.         inner_dim = n_heads * d_head
  7.         self.norm = Normalize(in_channels)
  8.         self.proj_in = nn.Conv2d(in_channels,
  9.                                  inner_dim,
  10.                                  kernel_size=1,
  11.                                  stride=1,
  12.                                  padding=0)
  13.         self.transformer_blocks = nn.ModuleList(
  14.             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
  15.                 for d in range(depth)]
  16.         )
  17.         self.proj_out = zero_module(nn.Conv2d(inner_dim,
  18.                                               in_channels,
  19.                                               kernel_size=1,
  20.                                               stride=1,
  21.                                               padding=0))

forward中,图像数据在进出Transformer模块前后都会做形状和通道数上的适配。运算结束后,结果和输入之间还会做一个残差连接。context就是约束信息编码,它会接入到交叉注意力层上。

  1. def forward(self, x, context=None):
  2.     b, c, h, w = x.shape
  3.     x_in = x
  4.     x = self.norm(x)
  5.     x = self.proj_in(x)
  6.     x = rearrange(x, 'b c h w -> b (h w) c')
  7.     for block in self.transformer_blocks:
  8.         x = block(x, context=context)
  9.     x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
  10.     x = self.proj_out(x)
  11.     return x + x_in

每一个Transformer模块的结构完全符合上文的示意图。如果你之前学过Transformer,那这些代码你会十分熟悉。我们快速把这部分代码浏览一遍。

  1. class BasicTransformerBlock(nn.Module):
  2.     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
  3.         super().__init__()
  4.         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
  5.         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
  6.         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
  7.                                     heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
  8.         self.norm1 = nn.LayerNorm(dim)
  9.         self.norm2 = nn.LayerNorm(dim)
  10.         self.norm3 = nn.LayerNorm(dim)
  11.         self.checkpoint = checkpoint
  12.     def forward(self, x, context=None):
  13.         x = self.attn1(self.norm1(x)) + x
  14.         x = self.attn2(self.norm2(x), context=context) + x
  15.         x = self.ff(self.norm3(x)) + x
  16.         return x

自注意力层和交叉注意力层都是用CrossAttention实现的。该模块与Transformer论文中的多头注意力机制完全相同。当forward的参数context=None时,模块其实只是一个提取特征的自注意力模块;而当context为约束文本的编码时,模块就是一个根据文本约束进行运算的交叉注意力模块。该模块用不到mask,相关的代码可以忽略。

  1. class CrossAttention(nn.Module):
  2.     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
  3.         super().__init__()
  4.         inner_dim = dim_head * heads
  5.         context_dim = default(context_dim, query_dim)
  6.         self.scale = dim_head ** -0.5
  7.         self.heads = heads
  8.         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
  9.         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
  10.         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
  11.         self.to_out = nn.Sequential(
  12.             nn.Linear(inner_dim, query_dim),
  13.             nn.Dropout(dropout)
  14.         )
  15.     def forward(self, x, context=None, mask=None):
  16.         h = self.heads
  17.         q = self.to_q(x)
  18.         context = default(context, x)
  19.         k = self.to_k(context)
  20.         v = self.to_v(context)
  21.         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
  22.         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
  23.         if exists(mask):
  24.             ...
  25.         # attention, what we cannot get enough of
  26.         attn = sim.softmax(dim=-1)
  27.         out = einsum('b i j, b j d -> b i d', attn, v)
  28.         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
  29.         return self.to_out(out)

Transformer块的内容到此结束。看完了SpatialTransformerResBlock,我们可以回头去看模块之间是怎么拼接的了。先来看U-Net的中间块。它其实就是一个ResBlock接一个SpatialTransformer再接一个ResBlock

  1. self.middle_block = TimestepEmbedSequential(
  2.     ResBlock(...),
  3.     SpatialTransformer(...),
  4.     ResBlock(...),
  5. )

下采样块input_blocks和上采样块output_blocks的结构几乎一模一样,区别只在于每一大层最后是做下采样还是上采样。这里我们以下采样块为例来学习一下这两个块的结构。

  1. self.input_blocks = nn.ModuleList(
  2.     [
  3.         TimestepEmbedSequential(
  4.             conv_nd(dims, in_channels, model_channels, 3, padding=1)
  5.         )
  6.     ]
  7. )
  8. for level, mult in enumerate(channel_mult):
  9.     for _ in range(num_res_blocks):
  10.         layers = [
  11.             ResBlock(...)]
  12.         ch = mult * model_channels
  13.         if ds in attention_resolutions:
  14.                 layers.append(
  15.                 AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...))
  16.         self.input_blocks.append(TimestepEmbedSequential(*layers))
  17.     if level != len(channel_mult) - 1:
  18.         out_ch = ch
  19.         self.input_blocks.append(
  20.             TimestepEmbedSequential(
  21.                 ResBlock(...)
  22.                 if resblock_updown
  23.                 else Downsample(...)
  24.             )
  25.         )

上采样块一开始是一个调整输入图片通道数的卷积层,它的作用和self.out输出层一样。

  1. self.input_blocks = nn.ModuleList(
  2.     [
  3.         TimestepEmbedSequential(
  4.             conv_nd(dims, in_channels, model_channels, 3, padding=1)
  5.         )
  6.     ]
  7. )

之后正式进行上采样块的构造。此处代码有两层循环,外层循环表示正在构造哪一个大层,内层循环表示正在构造该大层的哪一组模块。也就是说,共有len(channel_mult)个大层,每一大层都有num_res_blocks组相同的模块。在Stable Diffusion中,channel_mult=[1, 2, 4, 4]num_res_blocks=2

  1. for level, mult in enumerate(channel_mult):
  2.     for _ in range(num_res_blocks):
  3.         ...

每一组模块由一个ResBlock和一个SpatialTransformer构成。

  1. layers = [
  2.     ResBlock(...)
  3. ]
  4. ch = mult * model_channels
  5. if ds in attention_resolutions:
  6.     ...
  7.     layers.append(
  8.         SpatialTransformer(...)
  9.     )
  10. self.input_blocks.append(TimestepEmbedSequential(*layers))
  11. ...

构造完每一组模块后,若现在还没到最后一个大层,则添加一个下采样模块。Stable Diffusion有4个大层,只有运行到前3个大层时才会添加下采样模块。

  1. for level, mult in enumerate(channel_mult):
  2.     for _ in range(num_res_blocks):
  3.         ...
  4.     if level != len(channel_mult) - 1:
  5.         out_ch = ch
  6.         self.input_blocks.append(
  7.             TimestepEmbedSequential(
  8.                 ResBlock(...)
  9.                 if resblock_updown
  10.                 else Downsample(...)
  11.             )
  12.         )
  13.         ch = out_ch
  14.         input_block_chans.append(ch)
  15.         ds *= 2

至此,我们已经学完了Stable Diffusion的U-Net的主要实现代码。让我们来总结一下。U-Net是一种先对数据做下采样,再做上采样的网络结构。为了防止信息丢失,下采样模块和对应的上采样模块之间有残差连接。下采样块、中间块、上采样块都包含了ResBlockSpatialTransformer两种模块。ResBlock是图像网络中常使用的残差块,而SpatialTransformer是能够融合图像全局信息并融合不同模态信息的Transformer块。Stable Diffusion的U-Net的输入除了有图像外,还有时间戳t和约束编码ct会先过几个嵌入层和线性层,再输入进每一个ResBlock中。c会直接输入到所有Transformer块的交叉注意力块中。

Diffusers的源码会在下篇文章中解读。敬请期待!

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/849333
推荐阅读
相关标签
  

闽ICP备14008679号