赞
踩
在学习SD后,我个人猜测ControlNet和SD中的Classifier Guidance类似,另外SD有一个遗憾就是没有训练,只玩了采样过程,所以在ControlNet中,再搞清楚原理之后,还会自建数据集进行训练实践(下一篇)!
原理部分我们直接拿出Controlnet作者的图
网上对这张图的说明都比较粗略,我们首先要搞明白几个值得含义,参考下面这张图,我们会发现作者在原先的SD之外新增了一个condition的额外输入,这个condition就是我们增加的额外条件,而controlnet的训练是不影响原来的SD的,即原来的SD完全保持独立和不变,影响是通过在SD的Decoder阶段添加condition的编码信息,而训练的部分就说图中的蓝色方块,也就是SD的编码块,如果算力富余的话,controlnet也可以重新训练整个SD。
思路的原理还是比较简单的,以作者的图为例
原先的a输出是,x是输入的二维图像,是neural network block的参数
添加了Controlnet之后的b输出,Z(;Θz)指的是zero convolution,也就是一个3x3的卷积层,并将权重和偏差都初始化为0 也就是最开始的
Controlnet的损失函数也比较简单,在原先SD的基础之上增加了新的Cf
在代码部分,有了SD的基础,我们直接来看Controlnet的forward部分,我们可以看到除了经典的x,timesteps,context三件套输入外,还多了hint
def forward(self, x, hint, timesteps, context, **kwargs)
这个地方有了前面的基础,我们很容易就能找到调用模型的地方从而知道hint的来源
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
可以看到这里的cond和SD发生了变化,以scribble2img为例,其来源如下,也就是读取输入的图像以后,做一定的图像处理,就可以直接当作c_concat输入进模型了
- img = resize_image(HWC3(input_image), image_resolution)
- H, W, C = img.shape
-
- detected_map = np.zeros_like(img, dtype=np.uint8)
- detected_map[np.min(img, axis=2) < 127] = 255//这步是根据图像颜色取反,可以调出图片看看,比较抽象,最后的输入是黑底,然后白边是原输入的样子
-
- control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
- control = torch.stack([control for _ in range(num_samples)], dim=0)
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
-
- cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
- un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
知道了额外条件是如何输入的,我们再返回forward看其是如何被处理的,forward的前两步是对时间步编码,这个在SD中已经熟悉了,接着对将三个条件输入hint、emb、context都输入进了input_hint_block,我们来看一下这是个什么结构。
- def forward(self, x, hint, timesteps, context, **kwargs):
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
- emb = self.time_embed(t_emb)
-
- guided_hint = self.input_hint_block(hint, emb, context)
这里回顾一下TimestepEmbedSequential,其主要的作用其实就是分辨输入是否要和emb以及context结合,对于input_hint_block其主要作用是对于hint进行编码,所以输入其实只有hint一个,然后就是一系列的3*3卷积进行特征处理,值得注意的是最后一个卷积层加了一个zero_module,作用就是将该层的参数置0,这样做的目的也很明显,一方面是调整hint的通道数,使其可以被输入进Unet,另一方面也是做特征提取,和SD的输入一样,被置入隐式表达,可以减少显存占用,加快推理速度
- self.input_hint_block = TimestepEmbedSequential(
- conv_nd(dims, hint_channels, 16, 3, padding=1),
- nn.SiLU(),
- conv_nd(dims, 16, 16, 3, padding=1),
- nn.SiLU(),
- conv_nd(dims, 16, 32, 3, padding=1, stride=2),
- nn.SiLU(),
- conv_nd(dims, 32, 32, 3, padding=1),
- nn.SiLU(),
- conv_nd(dims, 32, 96, 3, padding=1, stride=2),
- nn.SiLU(),
- conv_nd(dims, 96, 96, 3, padding=1),
- nn.SiLU(),
- conv_nd(dims, 96, 256, 3, padding=1, stride=2),
- nn.SiLU(),
- zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
- )
-
- def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
- """
- A sequential module that passes timestep embeddings to the children that
- support it as an extra input.
- """
-
- def forward(self, x, emb, context=None):
- for layer in self:
- if isinstance(layer, TimestepBlock):
- x = layer(x, emb)
- elif isinstance(layer, SpatialTransformer):
- x = layer(x, context)
- else:
- x = layer(x)
- return x
后面的步骤和SD大同小异,首先module也就是SD中的input_blocks其由resnetblock和spatialtransformer组成,最后输出的就是噪音了,但是在这里噪音还要加速之前的guided_hint,我们来看一下这里新出现的zero_conv,其实就是一个置0的卷积层,而且是1*1的conv2d,这个过程虽然输入了emb和context,但是其都是不作用的,只对h做作用,那么此时的h就是guided_hint+原来SDUnet输出的噪音
- outs = []
-
- h = x.type(self.dtype)
- for module, zero_conv in zip(self.input_blocks, self.zero_convs):
- if guided_hint is not None:
- h = module(h, emb, context)
- h += guided_hint
- guided_hint = None
- else:
- h = module(h, emb, context)
- outs.append(zero_conv(h, emb, context))
-
-
- self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
-
- def make_zero_conv(self, channels):
- return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
在降采样并提取特征之后,是middle_block,其结构还是比较清晰的,最后就直接返回outs了。可见其推理过程是相当简单的,相比于SD核心就是增加了新的输入hint和其编码部分input_hint_block
- h = self.middle_block(h, emb, context)
- outs.append(self.middle_block_out(h, emb, context))
- return outs
-
- self.middle_block_out = self.make_zero_conv(ch)
-
- self.middle_block = TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
- disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
- use_checkpoint=use_checkpoint
- ),
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- )
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。