当前位置:   article > 正文

从Diffusion开始搞懂StableDiffusion_ldm.util.instantiate_from_config

ldm.util.instantiate_from_config

上篇文章介绍了Diffusion扩散模型的原理,这篇文章以此为基础,努力学会StableDiffusion的原理和实现。

原理部分参考大佬的文章,原理从文字和图来看并不复杂,但是当你亲自去实现的时候,就会发现这是一件多么困难的事情,入门易,下手难呀!

Stable Diffusion原理解读 - 知乎 (zhihu.com)

 从SD论文的架构图中,就可以明显看出SD于Diffusion模型的改进之处,

1、首先,SD增加了图片感知压缩(Perceptual Image Compression),也就是左边的红框区域。

图片感知压缩(Perceptual Image Compression)

简单来说,就是原本diffusion模型直接在像素层面训练,内存占用太大,引入感知压缩就是说通过VAE这类自编码模型对原图片进行处理,忽略掉图片中的高频信息,只保留重要、基础的一些特征。这种方法带来的的好处就像引文部分说的一样,能够大幅降低训练和采样阶段的计算复杂度,让文图生成等任务能够在消费级GPU上,在10秒级别时间生成图片,大大降低了落地门槛。

感知压缩主要利用一个预训练的自编码模型,该模型能够学习到一个在感知上等同于图像空间的潜在表示空间。这种方法的一个优势是只需要训练一个通用的自编码模型,就可以用于不同的扩散模型的训练,在不同的任务上使用。这样一来,感知压缩的方法除了应用在标准的无条件图片生成外,也可以十分方便的拓展到各种图像到图像(inpainting,super-resolution)和文本到图像(text-to-image)任务上。

由此可知,基于感知压缩的扩散模型的训练本质上是一个两阶段训练的过程,第一阶段需要训练一个自编码器,第二阶段才需要训练扩散模型本身。在第一阶段训练自编码器时,为了避免潜在表示空间出现高度的异化,作者使用了两种正则化方法,一种是KL-reg,另一种是VQ-reg,因此在官方发布的一阶段预训练模型中,会看到KL和VQ两种实现。在Stable Diffusion中主要采用AutoencoderKL这种实现。

 2、由于进行了图像编码,所以原来的diffussion model变成了latent diffusion model

 LDM就是上一篇文章得到的DM的损失函数,这里变成了LLDM,就是讲输入变成了latent的输入,Zt为编码器得到的结果。

3、最后便是SD引入了Conditioning Mechanisms,也就是输入文本等约束生成特定的图像,从图中可以看出,条件约束通过引入了一个领域专用编码器(domain specific encoder)

并且是在diffusion的逆向推理过程的Unet中增加了cross-attention机制来实现

简单来说就是加了一个transformer里的一个自注意力机制Q是原Zt,K和V由条件输入决定,损失函数变成了LLDM 

 Stable Diffusion 原理介绍与源码分析(一) - 知乎 (zhihu.com)

这里我们从三个模块的源码开始学习

首先是Encode_First_Stage,也就是将图像映射到隐藏层,这里引用珍妮大佬的图 说明这个过程发生了什么

 在SD源码中,位于img2img.py下,有以下代码调用此过程

  1. init_image = load_img(opt.init_img).to(device)
  2. init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
  3. init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space

 而在最新版的SD中其默认的推理config为

  1. parser.add_argument(
  2. "--config",
  3. type=str,
  4. default="configs/stable-diffusion/v2-inference.yaml",
  5. help="path to config which constructs model",
  6. )

在此config中,我们可以找到该层相关参数

  1. first_stage_config:
  2. target: ldm.models.autoencoder.AutoencoderKL
  3. params:
  4. embed_dim: 4
  5. monitor: val/rec_loss
  6. ddconfig:
  7. #attn_type: "vanilla-xformers"
  8. double_z: true
  9. z_channels: 4
  10. resolution: 256
  11. in_channels: 3
  12. out_ch: 3
  13. ch: 128
  14. ch_mult:
  15. - 1
  16. - 2
  17. - 4
  18. - 4
  19. num_res_blocks: 2
  20. attn_resolutions: []
  21. dropout: 0.0
  22. lossconfig:
  23. target: torch.nn.Identity

加载模型的代码如下,可见模型是通过ldm.util中的instantiate_from_config函数加载的

  1. from ldm.util import instantiate_from_config
  2. def load_model_from_config(config, ckpt, verbose=False):
  3. print(f"Loading model from {ckpt}")
  4. pl_sd = torch.load(ckpt, map_location="cpu")
  5. if "global_step" in pl_sd:
  6. print(f"Global Step: {pl_sd['global_step']}")
  7. sd = pl_sd["state_dict"]
  8. model = instantiate_from_config(config.model)
  9. m, u = model.load_state_dict(sd, strict=False)
  10. if len(m) > 0 and verbose:
  11. print("missing keys:")
  12. print(m)
  13. if len(u) > 0 and verbose:
  14. print("unexpected keys:")
  15. print(u)
  16. model.cuda()
  17. model.eval()
  18. return model

这里我们直接找到ldm.model中的autoencoder.py,即可找到这一层的源码,然后我们直接找到AutoencoderKL中的前向代码,可见其就是将输入编码然后解码,返回编解码结果dec和一个posterior

  1. def forward(self, input, sample_posterior=True):
  2. posterior = self.encode(input)
  3. if sample_posterior:
  4. z = posterior.sample()
  5. else:
  6. z = posterior.mode()
  7. dec = self.decode(z)
  8. return dec, posterior

 其encode和decode代码如下,其中最关键的encoder和decoder引用ldm.modules.diffusionmodules.model

  1. from ldm.modules.diffusionmodules.model import Encoder, Decoder
  2. def encode(self, x):
  3. h = self.encoder(x)
  4. moments = self.quant_conv(h)
  5. posterior = DiagonalGaussianDistribution(moments)
  6. return posterior
  7. def decode(self, z):
  8. z = self.post_quant_conv(z)
  9. dec = self.decoder(z)
  10. return dec

这里简单介绍一下AutoencoderKL,它来自Auto-Encoding Variational Bayes(VAE)这篇论文

AutoencoderKL (huggingface.co)

Auto-Encoding Variational Bayes(VAE) - 知乎 (zhihu.com)

Auto encoder是一种无监督算法,主要用于特征提取或数据降维。其思想非常简单,即输入特征X 经过encoder后抽象为hidden layer z,再将z经过decoder过程重新预测为\hat{X}。 

Auto encoder的目的是提取抽象特征z,其学习过程为最小化损失函数L(X,\hat{X}),

用于惩罚二者之间的差异,假设使用平方损失,则有:

 所以个人猜测AutoencoderKL实际上是用KL散度作为损失函数的AE,不知道对不对哈

VAE的目的:很多时候,当我们数据处理时,会遇到数据量不足的情况,这时我们会考虑使用生成模型生成数据。VAE即在AE的基础上引入变分的思想,使其能够进行数据生成。

而其思路是试图推断和学习有向概率图模型的隐分布z,并通过对z的采样来实现数据生成。

这里直接引用慕容三思大佬的原文,大佬文章真的写得非常通俗易懂 

这里其实和diffusion模型有些类似,具体过程应该是先通过encoder获得隐分布z,然后用类似扩散模型的方法生成\hat{X},然后计算输入输出的KL散度,调整decoder 

 而具体计算的过程也和Diffusion类似,yysy,VAE应该比diffusion早?所以应该是diffusion借鉴了VAE?

 然后是和diffusion类似的损失函数推理过程,同样使用了KL散度计算二者之间的差异,然后我也终于进一步看懂了损失函数的意义,最大化每个独立样本的边缘似然概率之和。而每一个样本的差异则是通过KL散度计算两个分布的差异和一个额外的损失函数构成

 这里要看懂还是得多学习概率论计算才行

 这个过程在具体实现时同样需要The reparameterization trick(重参数化)

VAE的具体步骤如下 

 接着我们来看一下AEKL的代码实现

  1. def encode(self, x):
  2. h = self.encoder(x)
  3. moments = self.quant_conv(h)//self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
  4. posterior = DiagonalGaussianDistribution(moments)
  5. return posterior
  6. def decode(self, z):
  7. z = self.post_quant_conv(z)//self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
  8. dec = self.decoder(z)
  9. return dec
  10. def forward(self, input, sample_posterior=True):
  11. posterior = self.encode(input)
  12. if sample_posterior:
  13. z = posterior.sample()
  14. else:
  15. z = posterior.mode()
  16. dec = self.decode(z)
  17. return dec, posterior

在AEKL的forward过程中,先进行encode编码,再进行sample取样,最后再进行解码,解码部分在SD中并不是跟着编码过程后马上进行的,而是在隐式空间内进行了diffusion操作,所以decoder部分放在后面来讲,我们来看encode具体做了什么:encode同样分三步,第一步是encoder中用多个resnet进行特征提取,然后是quant_conv进一步调整通道数,最后是DiagonalGaussianDistribution计算特征分布的均值、方差、标准差等

Encoder的前向推理步骤如下

  1. def forward(self, x):
  2. # timestep embedding
  3. temb = None
  4. # downsampling
  5. hs = [self.conv_in(x)]//self.conv_in = torch.nn.Conv2d(in_channels,self.ch,kernel_size=3,stride=1,padding=1),一个简单的3*3卷积
  6. for i_level in range(self.num_resolutions)://self.num_resolutions = len(ch_mult),ch_mult=(1,2,4,8),应该是指降采样的次数,ch_mult是降采样的倍数
  7. for i_block in range(self.num_res_blocks)://self.num_res_blocks = num_res_blocks,来自于输入,在config文件中设置值为2
  8. h = self.down[i_level].block[i_block](hs[-1], temb)
  9. if len(self.down[i_level].attn) > 0:
  10. h = self.down[i_level].attn[i_block](h)
  11. hs.append(h)
  12. if i_level != self.num_resolutions-1:
  13. hs.append(self.down[i_level].downsample(hs[-1]))
  14. # middle
  15. h = hs[-1]
  16. h = self.mid.block_1(h, temb)
  17. h = self.mid.attn_1(h)
  18. h = self.mid.block_2(h, temb)
  19. # end
  20. h = self.norm_out(h)
  21. h = nonlinearity(h)
  22. h = self.conv_out(h)
  23. return h

self.down的定义如下,可见其是由ResnetBlock和attn构成的,而v2-inference版本没有用注意力,所以此处的down部分仅有多个ResnetBlock构成

  1. self.down = nn.ModuleList()
  2. for i_level in range(self.num_resolutions):
  3. block = nn.ModuleList()
  4. attn = nn.ModuleList()
  5. block_in = ch*in_ch_mult[i_level]
  6. block_out = ch*ch_mult[i_level]
  7. for i_block in range(self.num_res_blocks):
  8. block.append(ResnetBlock(in_channels=block_in,
  9. out_channels=block_out,
  10. temb_channels=self.temb_ch,
  11. dropout=dropout))
  12. block_in = block_out
  13. if curr_res in attn_resolutions:
  14. attn.append(make_attn(block_in, attn_type=attn_type))
  15. down = nn.Module()
  16. down.block = block
  17. down.attn = attn
  18. if i_level != self.num_resolutions-1:
  19. down.downsample = Downsample(block_in, resamp_with_conv)
  20. curr_res = curr_res // 2
  21. self.down.append(down)
  22. //self.down
  23. ModuleList(
  24. (0): Module(
  25. (block): ModuleList(
  26. (0-1): 2 x ResnetBlock(
  27. (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
  28. (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  29. (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
  30. (dropout): Dropout(p=0.0, inplace=False)
  31. (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  32. )
  33. )
  34. (attn): ModuleList()
  35. (downsample): Downsample(
  36. (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
  37. )
  38. )
  39. (1): Module(
  40. (block): ModuleList(
  41. (0): ResnetBlock(
  42. (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
  43. (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  44. (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
  45. (dropout): Dropout(p=0.0, inplace=False)
  46. (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  47. (nin_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
  48. )
  49. (1): ResnetBlock(
  50. (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
  51. (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  52. (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
  53. (dropout): Dropout(p=0.0, inplace=False)
  54. (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  55. )
  56. )
  57. (attn): ModuleList()
  58. (downsample): Downsample(
  59. (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))
  60. )
  61. )
  62. (2): Module(
  63. (block): ModuleList(
  64. (0): ResnetBlock(
  65. (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
  66. (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  67. (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
  68. (dropout): Dropout(p=0.0, inplace=False)
  69. (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  70. (nin_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
  71. )
  72. (1): ResnetBlock(
  73. (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
  74. (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  75. (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
  76. (dropout): Dropout(p=0.0, inplace=False)
  77. (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  78. )
  79. )
  80. (attn): ModuleList()
  81. (downsample): Downsample(
  82. (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))
  83. )
  84. )
  85. (3): Module(
  86. (block): ModuleList(
  87. (0-1): 2 x ResnetBlock(
  88. (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
  89. (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  90. (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
  91. (dropout): Dropout(p=0.0, inplace=False)
  92. (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  93. )
  94. )
  95. (attn): ModuleList()
  96. )
  97. )

 这是经过降采样之后的结果,00是第一层con_in的结果,随后每三个为一组,对应res0,res1和downsample层,共三组9个,最后两个为一组,因为最后一层没有downsample了,只有res0,res1

 降采样之后是middle和end层,比较简单,就两层ResnetBlock,以及最后归一化然后再卷积一次输出,这样子网络设计的原因我暂不清楚,但具体而言encoder就是一系列Resnet提取图像的特征,最终的输出结果为[2,8,64,64],这里的2是由于在输入encode前对图像进行了复制处理。

  1. # middle
  2. self.mid = nn.Module()
  3. self.mid.block_1 = ResnetBlock(in_channels=block_in,
  4. out_channels=block_in,
  5. temb_channels=self.temb_ch,
  6. dropout=dropout)
  7. self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
  8. self.mid.block_2 = ResnetBlock(in_channels=block_in,
  9. out_channels=block_in,
  10. temb_channels=self.temb_ch,
  11. dropout=dropout)
  12. # end
  13. self.norm_out = Normalize(block_in)
  14. self.conv_out = torch.nn.Conv2d(block_in,
  15. 2*z_channels if double_z else z_channels,
  16. kernel_size=3,
  17. stride=1,
  18. padding=1)

这里有一个求高斯分布的函数DiagonalGaussianDistribution

  1. class DiagonalGaussianDistribution(object):
  2. def __init__(self, parameters, deterministic=False):
  3. self.parameters = parameters
  4. self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)//沿着通道维度将输入拆分为两份,得到mean[2,4,64,64],logvar[2,4,64,64],这里的2是输入原图及其复制得到的
  5. self.logvar = torch.clamp(self.logvar, -30.0, 20.0)//将输入input张量每个元素的范围限制到区间 [min,max],返回结果到一个新张量。具体而言比-30小的就是-30,比20大的就是20
  6. self.deterministic = deterministic//默认是false,这个单词的意思叫做确定性
  7. self.std = torch.exp(0.5 * self.logvar)//计算标准差
  8. self.var = torch.exp(self.logvar)//计算方差
  9. if self.deterministic:
  10. self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
  11. def sample(self):
  12. x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
  13. return x
  14. def kl(self, other=None):
  15. if self.deterministic:
  16. return torch.Tensor([0.])
  17. else:
  18. if other is None:
  19. return 0.5 * torch.sum(torch.pow(self.mean, 2)
  20. + self.var - 1.0 - self.logvar,
  21. dim=[1, 2, 3])
  22. else:
  23. return 0.5 * torch.sum(
  24. torch.pow(self.mean - other.mean, 2) / other.var
  25. + self.var / other.var - 1.0 - self.logvar + other.logvar,
  26. dim=[1, 2, 3])
  27. def nll(self, sample, dims=[1,2,3]):
  28. if self.deterministic:
  29. return torch.Tensor([0.])
  30. logtwopi = np.log(2.0 * np.pi)
  31. return 0.5 * torch.sum(
  32. logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
  33. dim=dims)
  34. def mode(self):
  35. return self.mean

高斯分布返回的结果,这地方感觉有点玄学...为啥chunk之后就是mean,还是这只是一个命名?然后为啥图像经过encode多个resnet卷积之后就是高斯分布? 后续有解释了会更新

 在SD的ddpm中,程序首先调用encode_first_stage也就是AEKL中的encode获得卷积之后的高斯分布,然后在get_first_stage_encoding中对高斯分布进行sample操作,所以最终的结果变为[2,4,64,64],也就是最终返回的隐式表达

    init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space
  1. @torch.no_grad()
  2. def encode_first_stage(self, x):
  3. return self.first_stage_model.encode(x)
  1. def get_first_stage_encoding(self, encoder_posterior):
  2. if isinstance(encoder_posterior, DiagonalGaussianDistribution):
  3. z = encoder_posterior.sample()
  4. elif isinstance(encoder_posterior, torch.Tensor):
  5. z = encoder_posterior
  6. else:
  7. raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
  8. return self.scale_factor * z
  1. def sample(self):
  2. x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
  3. return x

图像进入隐式空间后,根据珍妮大佬的图下一步应当是采样阶段?因为我并没有选择训练,但是如果是这样输入图像的目的是什么?SD确实是文本生成图像的框架,但是这个输入图像让我有点懵逼,而且还有img2img的脚本

DDPM (Denoising Diffusion Probabilistic Models)算法就是diffusion的前向和逆向过程,DDIM(Denoising Diffusion Implicit Models)则是对其速度上的改进,DDPM需要迭代上千次,而DDIM几十次就能有较好的结果,那么书接Diffusion和DDPM,让我们看看DDIM做了什么改进并取得了如此好的效果。参考大佬的DDIM介绍文章

扩散模型之DDIM - 知乎 (zhihu.com)

下面的内容基于科学空间苏剑林大佬的解读,感谢! 

生成扩散模型漫谈(四):DDIM = 高观点DDPM - 科学空间|Scientific Spaces (kexue.fm)

可知DDIM的核心思想就是打破马尔科夫链推导过程,重新定义每一步的推理分布,然后直接取中间步骤,从而定义一个更短的步数的前向过程,加速推导过程,也就是跳步采样的思想。

对于DDPM代码我已经在从0开始搞懂Diffusion扩散模型_fisherisfish的博客-CSDN博客文章中详细的介绍了,所以我们直接来看DDIM,当我第一次看DDIM代码的时候,我非常疑惑,因为DDIM代码里没有前向阶段,只有逆向阶段!再未完全了解DDIM的情况下,我做出了两种猜测1、DDIM沿用了DDPM的训练过程,2、DDIM不需要训练过程。在这里我个人是倾向于猜测1的,我们接着学习。

对于DDIM的整个过程,我们再一次回顾DDPM,DDPM的整体流程可以用下面式子来简化, 

注意式中前向阶段的p应当是q,其中是人为设定的高斯分布

前向阶段的损失函数如下 

我们可以看出对于训练过程,

而在此过程中,有一个问题就是推导采样阶段的贝叶斯公式时,

 是依赖于,如果我们不知道此过程,可否解出

 概率论中的知识表示这其实是可以的

 而且根据之前的结论,因为前向过程是正太分布,所以反向过程也应该是正太分布,

结论2:如果 q(Xt|X(t−1)) 满足高斯分布且方差β足够小,则q(X(t−1)|Xt)仍然是一个高斯分布。

 所以我们一般性地假设其为如下一个正太分布

我们为了不重新训练DDPM的前向阶段,所以保持DDPM的基础来求解这一过程

 可得到以下两个方程,即满足以下两个方程,就可以用直接推导并进一步得到采样阶段的计算公式

 求解可得以下结果,注意我们在假设公式4的时候,设定了三个未知量,而我们只有两个方程,所以求解时将\sigma _{t}当做已知量,表达另外两个式子,解如下

代入式4可得,再次注意在DDPM中\beta _{t} 是人为设定的一组极小的线性值,其中DDPM设定\beta _{0}=0,从1开始定义为是由0.0001 到0.02线性插值(插值数由T决定),在DDPM中\alpha _{t}=1-\beta _{t}\bar{\alpha _{t}}=\Pi _{1}^{t}\alpha_{t},故在这个式子中仅有\sigma _{t}仍是未知的,

请注意:苏佬文章定义的\beta _{t}是分布的标准差,而原论文中定义的\beta _{t}是方差,需要注意这一点差别,且苏佬文章中的\alpha _{t}也是DDPM中的开平方值,苏佬文章中还有一个\bar{\beta _{t}},并不是值\beta _{t}的累乘,而是\bar{\beta _{t}}=\sqrt{1-\bar{\alpha _{t}}}

 总结:现在我们在只给定p(xt|x0)、p(xt−1|x0)的情况下,通过待定系数法求解了p(xt−1|xt,x0)的一簇解,它带有一个自由参数σt。

我们的最终目标是得到采样公式,所以我们需要计算X0,来去除中的X0

 回顾DDPM的我们可以知道X0计算公式如下,注意这里论文中的\sqrt{\bar{\alpha _{t}}}等同于蓝色公式(苏佬文章)中的\bar{\alpha _{t}},原因是苏佬定义的\beta _{t}是分布的标准差,而原论文中定义的\beta _{t}是方差,需要注意这一点差别

在苏佬的文章中,用下式来表达X0,式中的噪声项,\epsilon _{\theta }是由Xt和t作为输入,由Unet估计出的噪声项

将式9作为X0,代入 式7就可得

式中 ,到了这一步我们会发现只需要通过定义\sigma _{t}就实现了从Xt到X(t-1)的计算,这里和DDPM的公式做比较,

 我们会发现当取,DDIM的推导公式就是DDPM的,注意此处的理解是DDPM是DDIM的一个特殊情况!

当我们将\sigma _{t}取0时,从Xt到X(t-1)就变成了一个固定的计算公式

总结:这也是DDIM论文中特别关心的一个例子,准确来说,原论文的DDIM就是特指σt=0的情形,其中“I”的含义就是“Implicit”,意思这是一个隐式的概率模型,因为跟其他选择所不同的是,此时从给定的xT=z出发,得到的生成结果x0是不带随机性的。后面我们将会看到,这在理论上和实用上都带来了一些好处。

 那么回归到核心问题,DDIM如何加速采样过程?核心在于跳步,这里提出一个观点就是

 DDPM的训练结果实质上包含了它的任意子序列参数的训练结果。

具体来说,我们训练了从[0,1,2...,T]的DDPM,那么[0~T]中的任意子序列参数的步骤也被训练了,这很好理解。

这里笔者后续又进一步加深了理解,主要是基于之前的计算公式,DDPM公式的推导是基于这个过程的,而DDIM的公式推导的目的就是排除这个过程,仅基于来计算,也就是说DDPM的公式是没办法跳步的,因为每一项的计算都依赖于相邻项的分布关系,但是DDIM通过忽视,直接从入手,而是基于人为假设的高斯分布,基于此实现的逆向公式就是可以跳步的,它并不依赖于相邻值的分布,只依赖于对于X0的分布,而对于X0的分布是已知的,所以我们就可以通过跳步来计算DDIM的采样公式!

那么当我们有一个已经训练好的T步的DDPM,我们从中取子序列来做采样阶段就可以了,假设该子序列有dim(\tau)步,那么其参数就是α¯τ1,α¯τ2,⋯,α¯dim(τ),其采样阶段也就只有dim(\tau)步了

 这就是加速的方法,那么问题又来了,为什么要训练一个T步长的DDPM呢?直接训练一个dim(\tau)步不就行了?这里苏佬给出了两点解释

1、训练更多步数的模型也许能增强泛化能力;2、通过子序列进行加速只是其中一种加速手段,训练更充分的T步允许我们尝试更多的其他加速手段,但并不会显著增加训练成本。

到这一步,我们先解答一下之前的问题,DDIM只是一种用于采样阶段的加速方法,其加速的思想是跳步,这里笔者心中遗留了一个问题就是子序列到底是怎么取的?还有自由变量\sigma _{t}的具体作用?所以我们进入SD的DDIM实现,来进一步一探究竟!

以SD的img2img.py脚本为例,在模型创建以后,就设定为DDIMSampler,前文中,我们介绍了SD的编码部分,但值得一提的是我们使用SD时,仅使用了采样阶段,所以我们其实没有训练时的编码,我们更多是对输入图像和文字进行编码,输入图像首先通过AEKL编码进入隐式表达,然后作为X0用前向公式计算出Xt,基于此Xt,再对文字用Clip进行编码,作为条件输入并入Unet的生成噪音阶段,通过噪音不断逆向计算出另一个X0,此X0同时基于文字和图像两个条件生成,这是SDimg2img的原理,而对于txt2img,Xt则是随机生成的高斯噪声,其余步骤则是类似的。

  1. sampler = DDIMSampler(model)
  2. class DDIMSampler(object):
  3. def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
  4. super().__init__()
  5. self.model = model
  6. self.ddpm_num_timesteps = model.num_timesteps
  7. self.schedule = schedule
  8. self.device = device

我们接着看SD的运行过程,在img2img中,首先对图像进行了编码,然后初始化采样器

  1. init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
  2. sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)

在make_schedule(附表)中,就是和DDPM中一样的初始化各种参数的过程,我们先来看一下ddim_timesteps,也就是采样的时间步t,根据之前的分析这肯定是和DDPM是不一样的,在 DDPM中直接就是设定时间步是1000,然后基于此去初始化β,而在 DDIM中,默认的取步方法是uniform(均匀),简单来说,设定ddim的timesteps数,默认是50,则相比DDPM少了20倍,加速效果也是理所应当的,所以DDIM的timesteps就是[0,20,40,...,960,980],值得注意的是这里对每一步都+1,所以最后结果是[1,21,41,...,961,981],作者给出的原因是

to get the final alpha values right (the ones from first scale to data during sampling)
  1. def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
  2. self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
  3. num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
  4. def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
  5. if ddim_discr_method == 'uniform':
  6. c = num_ddpm_timesteps // num_ddim_timesteps
  7. ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
  8. elif ddim_discr_method == 'quad':
  9. ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
  10. else:
  11. raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
  12. # assert ddim_timesteps.shape[0] == num_ddim_timesteps
  13. # add one to get the final alpha values right (the ones from first scale to data during sampling)
  14. steps_out = ddim_timesteps + 1
  15. if verbose:
  16. print(f'Selected timesteps for ddim sampler: {steps_out}')
  17. return steps_out

 接着是初始化\bar{\alpha }和β值,让我比较疑惑的是这里输入的\bar{\alpha }和β是从SD的v1.4.ckpt读入的,而其值并不是线性分布的,而且\bar{\alpha }值似乎也不是正常的1-β来的线性分布值累乘来的[0.9991, 0.9983, 0.9974,...,0.0047],0.9991=1*(1-0.0009),0.9983!=0.9991*(0.9991-0.0009),所以SD的v1.4版本应该不是通过线性插值来获取这两个值,或者说β不是,而\bar{\alpha }是计算误差?毕竟公式应该不会错。

  1. alphas_cumprod = self.model.alphas_cumprod
  2. assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
  3. to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
  4. self.register_buffer('betas', to_torch(self.model.betas))
  5. self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
  6. self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

这一部分是前向计算和ddpm中一模一样从上到下分别是\sqrt{\bar{\alpha }}\sqrt{1-\bar{\alpha }}log({1-\bar{\alpha }})\frac{1}{\sqrt{\bar{\alpha }}}\frac{\sqrt{1-\bar{\alpha }}}{\sqrt{\bar{\alpha }}}

  1. # calculations for diffusion q(x_t | x_{t-1}) and others
  2. self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
  3. self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
  4. self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
  5. self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
  6. self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

采样阶段的值则出现了明显的不同, make_ddim_sampling_parameters输入是\bar{\alpha },ddim的采样步数,eta默认是0。我们可以看到在此函数中,首先根据ddim_timesteps对\bar{\alpha }取值,即取下标为[1,21,41,...,961,981]的这些值,此处+1的效果就显现了,+1之后可以再往前取出相应的alphas_prev值了,保证了计算链的平衡。

ddim这里有一个计算sigmas的公式,和DDPM中似乎有一些出入,原因是在DDIM的论文里,α符号指代的是DDPM中的\bar{\alpha },这里从代码中也能看出,输入的是alphacums即α的累乘值,此处的sigma计算公式如下

 和前文的符号不同外,本质是一样的,这里再回顾一下sigma是p(X(t-1)|Xt)的标准差,如果取上式的值则和DDPM中的p(X(t-1)|Xt,X0)是一样的,在DDIM中给sigma加了一个系数值eta,当eta取1时,即DDPM,而DDIM给eta取0,此时的p(X(t-1)|Xt)是一个固定的式子

所以返回后的 ddim_alphas, ddim_alphas_prev实际上应该是ddim_alphas_cum, ddim_alphas_prev_cum

  1. # ddim sampling parameters
  2. ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
  3. ddim_timesteps=self.ddim_timesteps,
  4. eta=ddim_eta,verbose=verbose)
  5. def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
  6. # select alphas for computing the variance schedule
  7. alphas = alphacums[ddim_timesteps]
  8. alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
  9. # according the the formula provided in https://arxiv.org/abs/2010.02502
  10. sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
  11. if verbose:
  12. print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
  13. print(f'For the chosen value of eta, which is {eta}, '
  14. f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
  15. return sigmas, alphas, alphas_prev

继续初始化参数值, ddim_sqrt_one_minus_alphas实际上就是\bar{\beta _{t}},值得一提的是这里还计算了DDPM的sigma值,计算公式和ddim一致,也乘了eta,也就是0,具体作用后续再看

  1. self.register_buffer('ddim_sigmas', ddim_sigmas)
  2. self.register_buffer('ddim_alphas', ddim_alphas)
  3. self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
  4. self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
  5. sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
  6. (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
  7. 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
  8. self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

初始化参数值之后, 是对输入的随机编码,输入为输入图像编码后的隐式表达init_latent,还有取样的时间步t_enc=40

  1. t_enc = int(opt.strength * opt.ddim_steps)
  2. # encode (scaled latent)
  3. z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))

 我们通过回顾DDPM的公式,可知此处是通过X0,以及模型中的α值来初始化X40的值,而X40就是逆向阶段的起始高斯噪音图。

  1. @torch.no_grad()
  2. def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
  3. # fast, but does not allow for exact reconstruction
  4. # t serves as an index to gather the correct alphas
  5. if use_original_steps:
  6. sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
  7. sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
  8. else:
  9. sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
  10. sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
  11. if noise is None:
  12. noise = torch.randn_like(x0)
  13. return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
  14. extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

在采样阶段(逆向阶段)输入就是X40(z_enc),prompt编码后的特征值(c),时间步t_enc,后续的unconditional_guidance_scale和unconditional_conditioning涉及到基础知识classifier-free guidance

  1. # decode it
  2. samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
  3. unconditional_conditioning=uc, )

 先从Classifier Guidance介绍,来自于论文Diffusion Models Beat GANs on Image Synthesis,主要作用是使得扩散模型能够按类生成,具体而言就是借用别人训练好的扩散模型,我们自己再训练一个分类器,通过分类器来指定扩散模型的扩散过程。推荐苏佬的讲解,后文是本人的粗鄙理解,请见笑

生成扩散模型漫谈(九):条件控制生成结果 - 科学空间|Scientific Spaces (kexue.fm)

在DDPM的基础之上,增加条件y,那么逆向公式修正为

要计算这个式子,我们一步步来看,首先是P(X(t-1)|y)

我们用一次简单的贝叶斯公式,这个式子对于概率论新手也是友好的,hhh

接着对每一项,都补上条件Xt,这比直接对贝叶斯公式友好太多了,感谢苏佬的讲解。

 注意,在前向过程中,xt是由x(t−1)加噪声得到的,噪声不会对分类有帮助,所以xt的加入对分类不会有任何收益,因此有p(y|x(t−1),xt)=p(y|x(t−1)),从而

 接着对右指数进行一次泰勒展开,这个地方泰勒展开就不详细讲了

在DDPM中,我们假设逆向过程也是高斯分布如下:

 

 那么加上y以后的表达式结合式3、4如下

以此为结果,我们可以得到结论 

 

 那么我们就可以得到X(t-1)的计算公式,新增项如下

 此处和原论文略有差别,差别如下

 然后是经典的引入参数\gamma来调节条件参数的影响大小

 当γ>1时,生成过程将使用更多的分类器信号,结果将会提高生成结果与输入信号γ的相关性,但是会相应地降低生成结果的多样性;反之,则会降低生成结果与输入信号之间的相关性,但增加了多样性。

 《More Control for Free! Image Synthesis with Semantic Diffusion Guidance》论文对γ进行了更多的解释,后续有机会进一步学习,但对于SD我们理解到这一步应该是够了

对于Classifier-Free方案,来自于论文《Classifier-Free Diffusion Guidance》,它的思想是直接将条件作为模型输入之一,用来生成噪音

 训练的损失函数就是

 从DDPM一路学过来就会发现这些公式本质上就是加了y作为模型输入,

后续这部分缩放机制属于是作者个人的创新想法了,最后会浓缩为一个公式

式中α就是公式(24)中的\omega ,那么理解上来说,可以把c当作正向的prompt,\phi作为反向的prompt(negative prompt),也就是unconditional_conditioning,unconditional_guidance_scale就是α的大小

我们再回到代码部分,初版的SD,uc并非是通过人为输入的negative prompt,而是直接生成的

  1. uc = None
  2. if opt.scale != 1.0:
  3. uc = model.get_learned_conditioning(batch_size * [""])
  4. if isinstance(prompts, tuple):
  5. prompts = list(prompts)
  6. c = model.get_learned_conditioning(prompts)
  7. # decode it
  8. samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
  9. unconditional_conditioning=uc, )

decode函数相对比较简单,循环调用p_sample_ddim, 

  1. @torch.no_grad()
  2. def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
  3. use_original_steps=False, callback=None):
  4. timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
  5. timesteps = timesteps[:t_start]
  6. time_range = np.flip(timesteps)
  7. total_steps = timesteps.shape[0]
  8. print(f"Running DDIM Sampling with {total_steps} timesteps")
  9. iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
  10. x_dec = x_latent
  11. for i, step in enumerate(iterator):
  12. index = total_steps - i - 1
  13. ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
  14. x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
  15. unconditional_guidance_scale=unconditional_guidance_scale,
  16. unconditional_conditioning=unconditional_conditioning)
  17. if callback: callback(i)
  18. return x_dec

p_sample_ddim的代码看起来复杂,我们一段段来看,首先是前半部分由于SD是有uc输入的,所以第一个if运行else的内容,else里首先将输入的x40,和时间步t乘以2,这里讲一下t和index的区别,index实际上是DDIM的标签,比如第一次循环是从X40,生成X39,那么index值就是39,而t是原始DDPM的时间步,此时值是781,而我们的模型是DDPM训练的,所以当输入Unet计算噪音的时候,需要这个t值。乘以2的原因是需要将c和uc相结合,所以输入的X和T也乘2,来保证一致性。x_in大小[4,4,64,64],t_in大小[4,],c_in大小[4,77,768]

  1. @torch.no_grad()
  2. def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
  3. temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
  4. unconditional_guidance_scale=1., unconditional_conditioning=None,
  5. dynamic_threshold=None):
  6. b, *_, device = *x.shape, x.device
  7. if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
  8. model_output = self.model.apply_model(x, t, c)
  9. else:
  10. x_in = torch.cat([x] * 2)
  11. t_in = torch.cat([t] * 2)
  12. if isinstance(c, dict):
  13. assert isinstance(unconditional_conditioning, dict)
  14. c_in = dict()
  15. for k in c:
  16. if isinstance(c[k], list):
  17. c_in[k] = [torch.cat([
  18. unconditional_conditioning[k][i],
  19. c[k][i]]) for i in range(len(c[k]))]
  20. else:
  21. c_in[k] = torch.cat([
  22. unconditional_conditioning[k],
  23. c[k]])
  24. elif isinstance(c, list):
  25. c_in = list()
  26. assert isinstance(unconditional_conditioning, list)
  27. for i in range(len(c)):
  28. c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
  29. else:
  30. c_in = torch.cat([unconditional_conditioning, c])

借着就是作为输入预估噪声的步骤 

  1. model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
  2. model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

apply_model函数很简单,对cond做了一个增加c_crossattn的字典,然后调用model函数 

  1. def apply_model(self, x_noisy, t, cond, return_ids=False):
  2. if isinstance(cond, dict):
  3. # hybrid case, cond is expected to be a dict
  4. pass
  5. else:
  6. if not isinstance(cond, list):
  7. cond = [cond]
  8. key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
  9. cond = {key: cond}
  10. x_recon = self.model(x_noisy, t, **cond)
  11. if isinstance(x_recon, tuple) and not return_ids:
  12. return x_recon[0]
  13. else:
  14. return x_recon

 调用模型是DDPM中的DiffusionWrapper,forward代码如下,由于选择了c_crossattn所以核心代码就是elif self.conditioning_key == 'crossattn':那几行,具体而言就是cc = torch.cat(c_crossattn, 1)这一步把c_crossattn的dict又变成了cc这个tensor,大小和值均不变,是[4,77,768],然后就是输入进 out = self.diffusion_model(x, t, context=cc)

  1. class DiffusionWrapper(pl.LightningModule):
  2. def __init__(self, diff_model_config, conditioning_key):
  3. super().__init__()
  4. self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
  5. self.diffusion_model = instantiate_from_config(diff_model_config)
  6. self.conditioning_key = conditioning_key
  7. assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
  8. def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
  9. if self.conditioning_key is None:
  10. out = self.diffusion_model(x, t)
  11. elif self.conditioning_key == 'concat':
  12. xc = torch.cat([x] + c_concat, dim=1)
  13. out = self.diffusion_model(xc, t)
  14. elif self.conditioning_key == 'crossattn':
  15. if not self.sequential_cross_attn:
  16. cc = torch.cat(c_crossattn, 1)
  17. else:
  18. cc = c_crossattn
  19. if hasattr(self, "scripted_diffusion_model"):
  20. # TorchScript changes names of the arguments
  21. # with argument cc defined as context=cc scripted model will produce
  22. # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
  23. out = self.scripted_diffusion_model(x, t, cc)
  24. else:
  25. out = self.diffusion_model(x, t, context=cc)
  26. elif self.conditioning_key == 'hybrid':
  27. xc = torch.cat([x] + c_concat, dim=1)
  28. cc = torch.cat(c_crossattn, 1)
  29. out = self.diffusion_model(xc, t, context=cc)
  30. elif self.conditioning_key == 'hybrid-adm':
  31. assert c_adm is not None
  32. xc = torch.cat([x] + c_concat, dim=1)
  33. cc = torch.cat(c_crossattn, 1)
  34. out = self.diffusion_model(xc, t, context=cc, y=c_adm)
  35. elif self.conditioning_key == 'crossattn-adm':
  36. assert c_adm is not None
  37. cc = torch.cat(c_crossattn, 1)
  38. out = self.diffusion_model(x, t, context=cc, y=c_adm)
  39. elif self.conditioning_key == 'adm':
  40. cc = c_crossattn[0]
  41. out = self.diffusion_model(x, t, y=cc)
  42. else:
  43. raise NotImplementedError()
  44. return out

这里的diffusionmodel值得是modules/diffusionmodules/openaimodel.py文件中的class UNetModel(nn.Module),我们先看forwrad过程,作者注释了各个输入的含义,由于我们采用的Classifier-Free方案所以y是None

  1. def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
  2. """
  3. Apply the model to an input batch.
  4. :param x: an [N x C x ...] Tensor of inputs.
  5. :param timesteps: a 1-D batch of timesteps.
  6. :param context: conditioning plugged in via crossattn
  7. :param y: an [N] Tensor of labels, if class-conditional.
  8. :return: an [N x C x ...] Tensor of outputs.
  9. """

这里的timestep_embedding函数里出现了model_channels参数,该参数由Unet的config文件给出,在SD中一直是320, timestep_embedding的具体作用是对timesteps编码,最后结果是embedding[4,320],其中每一行值都是一样的,编码方式是sinusoidal timestep embeddings,即三角函数编码,具体计算过程看代码就明白了,编码完以后还有time_embed也就是两个线性层中间一个激活层,也就是一层MLP,激活函数是SiLU,最后的结果是[4,1280],其中1280是time_embed_dim,定义为model_channel的4倍。

Time Embedding的使用可以帮助深度学习模型更好地理解时间相关性,从而提高模型的性能。比如在Stable Diffusion中,将Time Embedding引入U-Net中,帮助其在扩散过程中从容预测噪声。

Stable Diffusion需要迭代多次对噪音进行逐步预测,使用Time Embedding就可以将time编码到网络中,从而在每一次迭代中让U-Net更加合适的噪声预测

图源见图中英文,图像复制过来就有水印,抱歉!深入浅出解析Stable Diffusion中U-Net的核心知识与价值 | 【算法兵器谱】 - 知乎 (zhihu.com)

  1. hs = []
  2. t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
  3. emb = self.time_embed(t_emb)
  4. def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
  5. """
  6. Create sinusoidal timestep embeddings.
  7. :param timesteps: a 1-D Tensor of N indices, one per batch element.
  8. These may be fractional.
  9. :param dim: the dimension of the output.
  10. :param max_period: controls the minimum frequency of the embeddings.
  11. :return: an [N x dim] Tensor of positional embeddings.
  12. """
  13. if not repeat_only:
  14. half = dim // 2
  15. freqs = torch.exp(
  16. -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
  17. ).to(device=timesteps.device)
  18. args = timesteps[:, None].float() * freqs[None]
  19. embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
  20. if dim % 2:
  21. embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
  22. else:
  23. embedding = repeat(timesteps, 'b -> b d', d=dim)
  24. return embedding
  25. time_embed_dim = model_channels * 4
  26. self.time_embed = nn.Sequential(
  27. linear(model_channels, time_embed_dim),
  28. nn.SiLU(),
  29. linear(time_embed_dim, time_embed_dim),
  30. )

 接下来的步骤和AEKL其实有点像,我们重点关注对于输入值的处理,这里的input_blocks我们需要详细看一下其组成

  1. h = x.type(self.dtype)
  2. for module in self.input_blocks:
  3. h = module(h, emb, context)
  4. hs.append(h)

 在模型初始化中,input_blocks首先加了一个TimestepEmbedSequential,这玩意输入里有emb,emb是时间步t的编码,但是由于layer就是一个简单的conv_nd,这里的dims默认是2,就是一个conv_2d,in_channels是4,和model_channels一样由config决定,不过in_channels需要和输入的x的通道数保持一致,x是[4,4,64,64],注意是第二个通道。

  1. self.input_blocks = nn.ModuleList(
  2. [
  3. TimestepEmbedSequential(
  4. conv_nd(dims, in_channels, model_channels, 3, padding=1)
  5. )
  6. ]
  7. )
  8. class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
  9. """
  10. A sequential module that passes timestep embeddings to the children that
  11. support it as an extra input.
  12. """
  13. def forward(self, x, emb, context=None):
  14. for layer in self:
  15. if isinstance(layer, TimestepBlock):
  16. x = layer(x, emb)
  17. elif isinstance(layer, SpatialTransformer):
  18. x = layer(x, context)
  19. else:
  20. x = layer(x)
  21. return x
  22. def conv_nd(dims, *args, **kwargs):
  23. """
  24. Create a 1D, 2D, or 3D convolution module.
  25. """
  26. if dims == 1:
  27. return nn.Conv1d(*args, **kwargs)
  28. elif dims == 2:
  29. return nn.Conv2d(*args, **kwargs)
  30. elif dims == 3:
  31. return nn.Conv3d(*args, **kwargs)
  32. raise ValueError(f"unsupported dimensions: {dims}")

在config文件中有一个重要参数叫做

channel_mult: [ 1, 2, 4, 4 ]

其主要作用就是给input_blocks循环添加ResBlock以及其他的卷积层

还有一个参数叫做

num_res_blocks:2

也就是每层都要重复2个相同的ResBlock以及其他的卷积层

我们先看每一层的组成,首先就是一个ResBlock,这里直接贴ResBlock的组成部分,注意这里的ResBlock输入有x和emb两个,组合方式是广播加法,即h=h+emb_out,注意h卷积之后shape为[4,320,64,64],emb_out则是[4,320,1,1]

  1. layers = [
  2. ResBlock(
  3. ch,
  4. time_embed_dim,
  5. dropout,
  6. out_channels=mult * model_channels,
  7. dims=dims,
  8. use_checkpoint=use_checkpoint,
  9. use_scale_shift_norm=use_scale_shift_norm,
  10. )
  11. ]
  12. def _forward(self, x, emb):
  13. h = self.in_layers(x)
  14. emb_out = self.emb_layers(emb).type(h.dtype)
  15. while len(emb_out.shape) < len(h.shape):
  16. emb_out = emb_out[..., None]
  17. h = h + emb_out
  18. h = self.out_layers(h)
  19. return self.skip_connection(x) + h
  20. (0): ResBlock(
  21. (in_layers): Sequential(
  22. (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  23. (1): SiLU()
  24. (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  25. )
  26. (h_upd): Identity()
  27. (x_upd): Identity()
  28. (emb_layers): Sequential(
  29. (0): SiLU()
  30. (1): Linear(in_features=1280, out_features=320, bias=True)
  31. )
  32. (out_layers): Sequential(
  33. (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  34. (1): SiLU()
  35. (2): Dropout(p=0, inplace=False)
  36. (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  37. )
  38. (skip_connection): Identity()
  39. )

接着是添加 SpatialTransformer

Cross Attention是一种多头注意力机制,它可以在两个不同的输入序列之间建立关联,并且可以将其中一个输入序列的信息传递给另一个输入序列。

在计算机视觉中,Cross Attention可以用于将图像与文本之间的关联建立。例如,在图像字幕生成任务中,Cross Attention可以将图像中的区域与生成的文字之间建立关联,以便生成更准确的描述。

Stable Diffusion中使用Cross Attention模块控制文本信息和图像信息的融合交互,通俗来说,控制U-Net把噪声矩阵的某一块与文本里的特定信息相对应。注意此图仅为单个CrossAttention结构,并不是SpatialTransformer的结构。

 在实际操作时有两个Cross Attention,第一个CrossAttention并不会将context作为输入,而是对X做自注意力,也就是普通的SA,第二个Cross Attention才加入context,其结构和上图一致,两个attn串联后再经过feedforward层

  1. if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
  2. layers.append(
  3. AttentionBlock(
  4. ch,
  5. use_checkpoint=use_checkpoint,
  6. num_heads=num_heads,
  7. num_head_channels=dim_head,
  8. use_new_attention_order=use_new_attention_order,
  9. ) if not use_spatial_transformer else SpatialTransformer(
  10. ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
  11. disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
  12. use_checkpoint=use_checkpoint
  13. )
  14. )
  15. //CrossAttention代码
  16. class CrossAttention(nn.Module):
  17. def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
  18. super().__init__()
  19. inner_dim = dim_head * heads
  20. context_dim = default(context_dim, query_dim)
  21. self.scale = dim_head ** -0.5
  22. self.heads = heads
  23. self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
  24. self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
  25. self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
  26. self.to_out = nn.Sequential(
  27. nn.Linear(inner_dim, query_dim),
  28. nn.Dropout(dropout)
  29. )
  30. def forward(self, x, context=None, mask=None):
  31. h = self.heads
  32. q = self.to_q(x)
  33. context = default(context, x)
  34. k = self.to_k(context)
  35. v = self.to_v(context)
  36. q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
  37. # force cast to fp32 to avoid overflowing
  38. if _ATTN_PRECISION =="fp32":
  39. with torch.autocast(enabled=False, device_type = 'cuda'):
  40. q, k = q.float(), k.float()
  41. sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
  42. else:
  43. sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
  44. del q, k
  45. if exists(mask):
  46. mask = rearrange(mask, 'b ... -> b (...)')
  47. max_neg_value = -torch.finfo(sim.dtype).max
  48. mask = repeat(mask, 'b j -> (b h) () j', h=h)
  49. sim.masked_fill_(~mask, max_neg_value)
  50. # attention, what we cannot get enough of
  51. sim = sim.softmax(dim=-1)
  52. out = einsum('b i j, b j d -> b i d', sim, v)
  53. out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
  54. return self.to_out(out)
  55. //2个CrossAttention组成BasicTransformerBlock
  56. class BasicTransformerBlock(nn.Module):
  57. ATTENTION_MODES = {
  58. "softmax": CrossAttention, # vanilla attention
  59. "softmax-xformers": MemoryEfficientCrossAttention
  60. }
  61. def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
  62. disable_self_attn=False):
  63. super().__init__()
  64. attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
  65. assert attn_mode in self.ATTENTION_MODES
  66. attn_cls = self.ATTENTION_MODES[attn_mode]
  67. self.disable_self_attn = disable_self_attn
  68. self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
  69. context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
  70. self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
  71. self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
  72. heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
  73. self.norm1 = nn.LayerNorm(dim)
  74. self.norm2 = nn.LayerNorm(dim)
  75. self.norm3 = nn.LayerNorm(dim)
  76. self.checkpoint = checkpoint
  77. def forward(self, x, context=None):
  78. return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
  79. def _forward(self, x, context=None):
  80. x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
  81. x = self.attn2(self.norm2(x), context=context) + x
  82. x = self.ff(self.norm3(x)) + x
  83. return x
  84. //SpatialTransformer的forward,和普通的Transformer类似,其中的block只有一个BasicTransformerBlock
  85. def forward(self, x, context=None):
  86. # note: if no context is given, cross-attention defaults to self-attention
  87. if not isinstance(context, list):
  88. context = [context]
  89. b, c, h, w = x.shape
  90. x_in = x
  91. x = self.norm(x)
  92. if not self.use_linear:
  93. x = self.proj_in(x)
  94. x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
  95. if self.use_linear:
  96. x = self.proj_in(x)
  97. for i, block in enumerate(self.transformer_blocks):
  98. x = block(x, context=context[i])
  99. if self.use_linear:
  100. x = self.proj_out(x)
  101. x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
  102. if not self.use_linear:
  103. x = self.proj_out(x)
  104. return x + x_in
  105. (1): SpatialTransformer(
  106. (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
  107. (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  108. (transformer_blocks): ModuleList(
  109. (0): BasicTransformerBlock(
  110. (attn1): CrossAttention(
  111. (to_q): Linear(in_features=320, out_features=320, bias=False)
  112. (to_k): Linear(in_features=320, out_features=320, bias=False)
  113. (to_v): Linear(in_features=320, out_features=320, bias=False)
  114. (to_out): Sequential(
  115. (0): Linear(in_features=320, out_features=320, bias=True)
  116. (1): Dropout(p=0.0, inplace=False)
  117. )
  118. )
  119. (ff): FeedForward(
  120. (net): Sequential(
  121. (0): GEGLU(
  122. (proj): Linear(in_features=320, out_features=2560, bias=True)
  123. )
  124. (1): Dropout(p=0.0, inplace=False)
  125. (2): Linear(in_features=1280, out_features=320, bias=True)
  126. )
  127. )
  128. (attn2): CrossAttention(
  129. (to_q): Linear(in_features=320, out_features=320, bias=False)
  130. (to_k): Linear(in_features=768, out_features=320, bias=False)
  131. (to_v): Linear(in_features=768, out_features=320, bias=False)
  132. (to_out): Sequential(
  133. (0): Linear(in_features=320, out_features=320, bias=True)
  134. (1): Dropout(p=0.0, inplace=False)
  135. )
  136. )
  137. (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  138. (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  139. (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  140. )
  141. )
  142. (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  143. )
  144. )

我们先总结一下,此时一个ResNet后接一个SpatialTransformer,然后该组件根据num_res_blocks

需要循环两次接入模型中,其中ResNet的输入是X和emb,也就是输入图像和时间之间特征,SpatialTransformer则是在此特征上再加上了context,至此三个输入之间都建立起了联系,我们也搞明白了是如何处理的,最后就是根据Unet的结构搭建整个模型了,中间还有降采样和上采样过程

  1. (1-2): 2 x TimestepEmbedSequential(
  2. (0): ResBlock(
  3. (in_layers): Sequential(
  4. (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  5. (1): SiLU()
  6. (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  7. )
  8. (h_upd): Identity()
  9. (x_upd): Identity()
  10. (emb_layers): Sequential(
  11. (0): SiLU()
  12. (1): Linear(in_features=1280, out_features=320, bias=True)
  13. )
  14. (out_layers): Sequential(
  15. (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
  16. (1): SiLU()
  17. (2): Dropout(p=0, inplace=False)
  18. (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  19. )
  20. (skip_connection): Identity()
  21. )
  22. (1): SpatialTransformer(
  23. (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
  24. (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  25. (transformer_blocks): ModuleList(
  26. (0): BasicTransformerBlock(
  27. (attn1): CrossAttention(
  28. (to_q): Linear(in_features=320, out_features=320, bias=False)
  29. (to_k): Linear(in_features=320, out_features=320, bias=False)
  30. (to_v): Linear(in_features=320, out_features=320, bias=False)
  31. (to_out): Sequential(
  32. (0): Linear(in_features=320, out_features=320, bias=True)
  33. (1): Dropout(p=0.0, inplace=False)
  34. )
  35. )
  36. (ff): FeedForward(
  37. (net): Sequential(
  38. (0): GEGLU(
  39. (proj): Linear(in_features=320, out_features=2560, bias=True)
  40. )
  41. (1): Dropout(p=0.0, inplace=False)
  42. (2): Linear(in_features=1280, out_features=320, bias=True)
  43. )
  44. )
  45. (attn2): CrossAttention(
  46. (to_q): Linear(in_features=320, out_features=320, bias=False)
  47. (to_k): Linear(in_features=768, out_features=320, bias=False)
  48. (to_v): Linear(in_features=768, out_features=320, bias=False)
  49. (to_out): Sequential(
  50. (0): Linear(in_features=320, out_features=320, bias=True)
  51. (1): Dropout(p=0.0, inplace=False)
  52. )
  53. )
  54. (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  55. (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  56. (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  57. )
  58. )
  59. (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  60. )
  61. )

整个模型如下图

 我们再回到前面调用模型的步骤,根据之前的推导Unet的输出是噪音,也就是这个x_recon

  1. def apply_model(self, x_noisy, t, cond, return_ids=False):
  2. if isinstance(cond, dict):
  3. # hybrid case, cond is expected to be a dict
  4. pass
  5. else:
  6. if not isinstance(cond, list):
  7. cond = [cond]
  8. key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
  9. cond = {key: cond}
  10. x_recon = self.model(x_noisy, t, **cond)
  11. if isinstance(x_recon, tuple) and not return_ids:
  12. return x_recon[0]
  13. else:
  14. return x_recon

注意我们之前在输入的时候,x_in是复制了两份的,所以输出后需要chunk,这里结合之前的c_in是uncond+c,所以此时前两个通道是model_uncond,是,后两个通道才是我们心心念念的model_t,也就是 

model_output就是下面这个公式

  1. model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
  2. model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

再之后就是反向计算X(t-1)了,在classier-free中,计算公式是和DDPM\DDIM一致的,

在代码中,首先计算了X0,依据公式是

然后在基于X0,计算X(t-1),依据公式和上面是一样的,但是后面的噪音部分有变化,具体而言就是重新加上了sigma_t,不过在DDIM中,sigma_t是取0的,所以还是一样的....,相当于增加一些随机性吧,这里其实有一点疑惑在于既然可以直接计算X0了,为什么还要一步步计算X(t-1),个人的猜测是直接计算X0,效果可能非常差,还有一个可能是没有随机性,生成图像和输入图像是一样的

  1. pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
  2. # direction pointing to x_t
  3. dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
  4. noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
  5. if noise_dropout > 0.:
  6. noise = torch.nn.functional.dropout(noise, p=noise_dropout)
  7. x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise

 后面的这些步骤其实就简单了,有些人可能会疑惑为啥是samples而不是sample,因为最开始的输入的z_enc就是[2,4,64,64],所以最后的samples也是[2,4,64,64],解码之后就会生成两张图像,而且还是一正一反的,这个反感觉就是model_uncond给出的

  1. # decode it
  2. samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
  3. unconditional_conditioning=uc, )
  4. x_samples = model.decode_first_stage(samples)
  5. x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
  6. for x_sample in x_samples:
  7. x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
  8. img = Image.fromarray(x_sample.astype(np.uint8))
  9. img = put_watermark(img, wm_encoder)
  10. img.save(os.path.join(sample_path, f"{base_count:05}.png"))
  11. base_count += 1
  12. all_samples.append(x_samples)

 总结:这篇文章到这里已经4W字了,可能中间有非常多没有必要的内容,但对于一个初学者来说,想要完全搞懂SD这些内容可能都是不够的,对于我个人来说,写完以后,对于SD有了非常清晰的认识,解决了非常多的困惑,学会了很多新知识,下一篇文章将是ControlNet

在训练阶段,SD首先对输入图像进行AEKL编码,SD的AEKL编码由许多ResNet块组成进行特征提取并编码,并经过一个求高斯分布的函数DiagonalGaussianDistribution进行高斯采样,最终将输入图像[2,3,512,512],编码为[2,4,64,64],后续的训练阶段同DDPM,进行前向推导,计算KL散度损失函数,优化Unet模型,在训练过程中,同样有prompts编码的c以及时间步t作为输入。

在采样阶段,对于txt2img随机生成初始的高斯噪音图像,对于img2img,输入的img同样需要AEKL编码进隐式表达,latent_input会根据模型ckpt提供的α值和β值直接计算为X40,作为初始的高斯分布噪音从而影响到图像生成,此外采样阶段采用DDIM方法,进行跳步采样,并且改变了反向过程的计算公式,实现了加速。

SD使用的Unet模型中加入了Time Embedding,作用是编码时间步使之可以加入到Unet作为输入,时间步特征加入到了Unet的ResNet Block中;还加入了crossattention,其是self-attention的变种,可以在两个不同的输入序列之间建立关联,并且可以将其中一个输入序列的信息传递给另一个输入序列,并由crossattention构建SpatialTransformer,其输入了文本编码形成的context信息。

遗留的问题:

1、SD中AEKL和Unet结构的原因,是否是大小和性能的平衡?

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/225746?site
推荐阅读
相关标签
  

闽ICP备14008679号