当前位置:   article > 正文

从SD开始搞懂ControlNet_sd controlnet

sd controlnet

在学习SD后,我个人猜测ControlNet和SD中的Classifier Guidance类似,另外SD有一个遗憾就是没有训练,只玩了采样过程,所以在ControlNet中,再搞清楚原理之后,还会自建数据集进行训练实践(下一篇)!

一、原理介绍

原理部分我们直接拿出Controlnet作者的图

网上对这张图的说明都比较粗略,我们首先要搞明白几个值得含义,参考下面这张图,我们会发现作者在原先的SD之外新增了一个condition的额外输入,这个condition就是我们增加的额外条件,而controlnet的训练是不影响原来的SD的,即原来的SD完全保持独立和不变,影响是通过在SD的Decoder阶段添加condition的编码信息,而训练的部分就说图中的蓝色方块,也就是SD的编码块,如果算力富余的话,controlnet也可以重新训练整个SD。

思路的原理还是比较简单的,以作者的图为例

原先的a输出是,x是输入的二维图像,\thetaneural network block的参数 

添加了Controlnet之后的b输出,Z(;Θz)指的是zero convolution,也就是一个3x3的卷积层,并将权重和偏差都初始化为0 也就是最开始的

Controlnet的损失函数也比较简单,在原先SD的基础之上增加了新的Cf

 二、代码分析

在代码部分,有了SD的基础,我们直接来看Controlnet的forward部分,我们可以看到除了经典的x,timesteps,context三件套输入外,还多了hint

    def forward(self, x, hint, timesteps, context, **kwargs)

这个地方有了前面的基础,我们很容易就能找到调用模型的地方从而知道hint的来源

control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)

 可以看到这里的cond和SD发生了变化,以scribble2img为例,其来源如下,也就是读取输入的图像以后,做一定的图像处理,就可以直接当作c_concat输入进模型了

  1. img = resize_image(HWC3(input_image), image_resolution)
  2. H, W, C = img.shape
  3. detected_map = np.zeros_like(img, dtype=np.uint8)
  4. detected_map[np.min(img, axis=2) < 127] = 255//这步是根据图像颜色取反,可以调出图片看看,比较抽象,最后的输入是黑底,然后白边是原输入的样子
  5. control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
  6. control = torch.stack([control for _ in range(num_samples)], dim=0)
  7. control = einops.rearrange(control, 'b h w c -> b c h w').clone()
  8. cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
  9. un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}

知道了额外条件是如何输入的,我们再返回forward看其是如何被处理的,forward的前两步是对时间步编码,这个在SD中已经熟悉了,接着对将三个条件输入hint、emb、context都输入进了input_hint_block,我们来看一下这是个什么结构。

  1. def forward(self, x, hint, timesteps, context, **kwargs):
  2. t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
  3. emb = self.time_embed(t_emb)
  4. guided_hint = self.input_hint_block(hint, emb, context)

这里回顾一下TimestepEmbedSequential,其主要的作用其实就是分辨输入是否要和emb以及context结合,对于input_hint_block其主要作用是对于hint进行编码,所以输入其实只有hint一个,然后就是一系列的3*3卷积进行特征处理,值得注意的是最后一个卷积层加了一个zero_module,作用就是将该层的参数置0,这样做的目的也很明显,一方面是调整hint的通道数,使其可以被输入进Unet,另一方面也是做特征提取,和SD的输入一样,被置入隐式表达,可以减少显存占用,加快推理速度

  1. self.input_hint_block = TimestepEmbedSequential(
  2. conv_nd(dims, hint_channels, 16, 3, padding=1),
  3. nn.SiLU(),
  4. conv_nd(dims, 16, 16, 3, padding=1),
  5. nn.SiLU(),
  6. conv_nd(dims, 16, 32, 3, padding=1, stride=2),
  7. nn.SiLU(),
  8. conv_nd(dims, 32, 32, 3, padding=1),
  9. nn.SiLU(),
  10. conv_nd(dims, 32, 96, 3, padding=1, stride=2),
  11. nn.SiLU(),
  12. conv_nd(dims, 96, 96, 3, padding=1),
  13. nn.SiLU(),
  14. conv_nd(dims, 96, 256, 3, padding=1, stride=2),
  15. nn.SiLU(),
  16. zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
  17. )
  18. def zero_module(module):
  19. """
  20. Zero out the parameters of a module and return it.
  21. """
  22. for p in module.parameters():
  23. p.detach().zero_()
  24. return module
  25. class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
  26. """
  27. A sequential module that passes timestep embeddings to the children that
  28. support it as an extra input.
  29. """
  30. def forward(self, x, emb, context=None):
  31. for layer in self:
  32. if isinstance(layer, TimestepBlock):
  33. x = layer(x, emb)
  34. elif isinstance(layer, SpatialTransformer):
  35. x = layer(x, context)
  36. else:
  37. x = layer(x)
  38. return x

后面的步骤和SD大同小异,首先module也就是SD中的input_blocks其由resnetblock和spatialtransformer组成,最后输出的就是噪音了,但是在这里噪音还要加速之前的guided_hint,我们来看一下这里新出现的zero_conv,其实就是一个置0的卷积层,而且是1*1的conv2d,这个过程虽然输入了emb和context,但是其都是不作用的,只对h做作用,那么此时的h就是guided_hint+原来SDUnet输出的噪音

  1. outs = []
  2. h = x.type(self.dtype)
  3. for module, zero_conv in zip(self.input_blocks, self.zero_convs):
  4. if guided_hint is not None:
  5. h = module(h, emb, context)
  6. h += guided_hint
  7. guided_hint = None
  8. else:
  9. h = module(h, emb, context)
  10. outs.append(zero_conv(h, emb, context))
  11. self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
  12. def make_zero_conv(self, channels):
  13. return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

 在降采样并提取特征之后,是middle_block,其结构还是比较清晰的,最后就直接返回outs了。可见其推理过程是相当简单的,相比于SD核心就是增加了新的输入hint和其编码部分input_hint_block

  1. h = self.middle_block(h, emb, context)
  2. outs.append(self.middle_block_out(h, emb, context))
  3. return outs
  4. self.middle_block_out = self.make_zero_conv(ch)
  5. self.middle_block = TimestepEmbedSequential(
  6. ResBlock(
  7. ch,
  8. time_embed_dim,
  9. dropout,
  10. dims=dims,
  11. use_checkpoint=use_checkpoint,
  12. use_scale_shift_norm=use_scale_shift_norm,
  13. ),
  14. AttentionBlock(
  15. ch,
  16. use_checkpoint=use_checkpoint,
  17. num_heads=num_heads,
  18. num_head_channels=dim_head,
  19. use_new_attention_order=use_new_attention_order,
  20. ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
  21. ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
  22. disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
  23. use_checkpoint=use_checkpoint
  24. ),
  25. ResBlock(
  26. ch,
  27. time_embed_dim,
  28. dropout,
  29. dims=dims,
  30. use_checkpoint=use_checkpoint,
  31. use_scale_shift_norm=use_scale_shift_norm,
  32. ),
  33. )

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Li_阴宅/article/detail/830738
推荐阅读
相关标签
  

闽ICP备14008679号