References:
https://zhuanlan.zhihu.com/p/642354007
https://zhuanlan.zhihu.com/p/677234407
https://zhuanlan.zhihu.com/p/632809634
https://proceedings.neurips.cc/paper/2020/hash/4c5bcfec8584af0d967f1ab10179ca4b-Abstract.html
The single-step forward (noising) process is

x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1 - \alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
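The closed form below comes from unrolling this recursion: each unrolling step merges two independent Gaussian noises into a single one (this is the standard DDPM derivation; \bar{\epsilon} denotes the merged noise):

\begin{aligned}
x_t &= \sqrt{\alpha_t}\bigl(\sqrt{\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_{t-1}}\,\epsilon_{t-1}\bigr) + \sqrt{1-\alpha_t}\,\epsilon_t \\
    &= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar{\epsilon}
\end{aligned}

Repeating this all the way down to x_0, with \overline{\alpha}_t = \prod_{s=1}^{t}\alpha_s, gives the closed form.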
Hence x_t conditioned on x_0 follows the Gaussian distribution:
q(x_t \mid x_0) = \mathcal{N}\!\bigl(x_t;\; \sqrt{\overline{\alpha}_t}\,x_0,\; (1 - \overline{\alpha}_t)\,I\bigr)
Its reparameterized (sampling) form is:

x_t = \sqrt{\overline{\alpha}_t}\,x_0 + \sqrt{1 - \overline{\alpha}_t}\,\epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, I)
Therefore, during training, x_t for any timestep t can be obtained from x_0 in a single step to compute the L2 (noise-prediction) loss.
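As a concrete sketch of this training step (assuming an image batch x0, a noise-prediction network called as model(x_t, t), and a simple linear beta schedule; these are placeholders for illustration, not any specific implementation):

import torch
import torch.nn.functional as F

# Illustrative linear noise schedule (placeholder values).
T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # beta_t
alphas = 1.0 - betas                          # alpha_t
alphas_bar = torch.cumprod(alphas, dim=0)     # \bar{alpha}_t = prod_{s<=t} alpha_s

def ddpm_training_step(model, x0):
    """One DDPM training step: sample t, form x_t in closed form, regress the noise."""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)        # random timestep per sample
    eps = torch.randn_like(x0)                              # epsilon ~ N(0, I)
    a_bar = alphas_bar.to(x0.device)[t].view(b, 1, 1, 1)    # broadcast over (C, H, W)
    xt = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps     # x_t = sqrt(a_bar) x0 + sqrt(1 - a_bar) eps
    eps_pred = model(xt, t)                                 # epsilon_theta(x_t, t)
    return F.mse_loss(eps_pred, eps)                        # simplified L2 objective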
The initial optimization target is the negative log-likelihood:

-\log p_\theta(x_0)
However, this likelihood is intractable, so DDPM instead minimizes its variational upper bound, which decomposes into a sum of per-timestep terms.
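In the standard DDPM formulation the bound splits as:

\begin{aligned}
\mathbb{E}\bigl[-\log p_\theta(x_0)\bigr]
&\le \mathbb{E}_q\!\left[\log \frac{q(x_{1:T}\mid x_0)}{p_\theta(x_{0:T})}\right]
 = L_T + \sum_{t>1} L_{t-1} + L_0, \\
L_T &= D_{\mathrm{KL}}\bigl(q(x_T \mid x_0)\,\|\,p(x_T)\bigr), \\
L_{t-1} &= D_{\mathrm{KL}}\bigl(q(x_{t-1} \mid x_t, x_0)\,\|\,p_\theta(x_{t-1}\mid x_t)\bigr), \\
L_0 &= -\log p_\theta(x_0 \mid x_1).
\end{aligned}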
Therefore only L_{t-1} needs to be optimized, since L_T and L_0 are treated as constants.
Minimizing the KL divergence between two Gaussians amounts to matching their means and variances. The variance \sigma_t^2 is assumed in DDPM to be the fixed value \beta_t, so it is not optimized.
What remains is the mean term, which the derivation expresses in terms of the predicted noise (see below).
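In the standard DDPM parameterization, the model mean is written using the predicted noise \epsilon_\theta, and the (re-weighted) KL term reduces to an L2 loss on the noise:

\begin{aligned}
\mu_\theta(x_t, t) &= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\overline{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right), \\
L_{\text{simple}} &= \mathbb{E}_{t,\,x_0,\,\epsilon}\Bigl[\bigl\|\epsilon - \epsilon_\theta\bigl(\sqrt{\overline{\alpha}_t}\,x_0 + \sqrt{1-\overline{\alpha}_t}\,\epsilon,\; t\bigr)\bigr\|^2\Bigr],
\end{aligned}

which is exactly the L2 noise-prediction loss mentioned above.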
Pseudocode for LDM text-to-image sampling with a DDIM scheduler:

def ldm_text_to_image(image_shape, text, ddim_steps=20, eta=0):
    ddim_scheduler = DDIMScheduler()
    vae = VAE()
    unet = UNet()
    zt = randn(image_shape)
    T = 1000
    timesteps = ddim_scheduler.get_timesteps(T, ddim_steps)  # [1000, 950, 900, ...]
    text_encoder = CLIP()
    c = text_encoder.encode(text)
    for t in timesteps:
        eps = unet(zt, t, c)                           # predict noise conditioned on the text
        std = ddim_scheduler.get_std(t, eta)           # eta = 0 gives deterministic DDIM
        zt = ddim_scheduler.get_xt_prev(zt, t, eps, std)
    xt = vae.decoder.decode(zt)                        # decode the latent back to pixel space
    return xt
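For comparison, a minimal runnable sketch with the actual diffusers API; the checkpoint id "runwayml/stable-diffusion-v1-5" and the prompt are just examples:

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Example checkpoint; assumes a CUDA GPU is available.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Swap the default scheduler for DDIM, reusing its configuration.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=20,   # ddim_steps in the pseudocode above
    eta=0.0,                  # eta = 0 -> deterministic DDIM sampling
).images[0]
image.save("out.png")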
In the UNet, every attention module corresponds to an AttentionProcessor class and instance, maintained through the UNet's attn_processors dictionary. Its keys are the attention modules' positions, i.e. the module names in the network; to customize an attention block, replace the corresponding AttentionProcessor entry in attn_processors. For example:
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no cross-attention dim); attn2 is cross-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        # Self-attention keeps the default processor.
        attn_procs[name] = AttnProcessor()
    else:
        # Cross-attention gets the IP-Adapter processor; its new image projections
        # to_k_ip / to_v_ip are initialized from the existing to_k / to_v weights.
        layer_name = name.split(".processor")[0]
        weights = {
            "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
            "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
        }
        attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        attn_procs[name].load_state_dict(weights)
unet.set_attn_processor(attn_procs)
Below is the original AttnProcessor; you can modify its logic or add new learnable parameters. For example, IP-Adapter adds an extra cross-attention branch over image features. Note, however, that the signature of __call__() cannot be changed; to pass in new inputs, concatenate them onto encoder_hidden_states and split them apart inside the processor (a sketch of this trick follows the class below).
import torch
import torch.nn as nn


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # flatten (B, C, H, W) feature maps into a token sequence
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # self-attention: keys/values come from hidden_states itself
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
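As a sketch of the concatenation trick mentioned above (not the actual IPAttnProcessor implementation): append a fixed number of extra tokens, e.g. projected image embeddings, to encoder_hidden_states before calling the UNet, then split them back out inside a custom processor. The names num_extra_tokens, to_k_ip and to_v_ip follow IP-Adapter's naming but are illustrative; for brevity the sketch omits the 4-D reshaping, group-norm and residual handling shown in the default processor above.

import torch
import torch.nn as nn


class ExtraTokenAttnProcessor(nn.Module):
    """Cross-attention processor that attends to text tokens plus extra tokens
    (e.g. image embeddings) passed in by concatenation onto encoder_hidden_states."""

    def __init__(self, hidden_size, cross_attention_dim, num_extra_tokens=4, scale=1.0):
        super().__init__()
        self.num_extra_tokens = num_extra_tokens
        self.scale = scale
        # New learnable projections for the extra tokens (IP-Adapter-style naming).
        self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size, bias=False)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        # Split the concatenated context back into text tokens and extra tokens.
        end = encoder_hidden_states.shape[1] - self.num_extra_tokens
        text_states = encoder_hidden_states[:, :end, :]
        extra_states = encoder_hidden_states[:, end:, :]

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        # Standard text cross-attention.
        key = attn.head_to_batch_dim(attn.to_k(text_states))
        value = attn.head_to_batch_dim(attn.to_v(text_states))
        out = torch.bmm(attn.get_attention_scores(query, key, attention_mask), value)

        # Additional cross-attention over the extra tokens, added with a scale factor.
        key_ip = attn.head_to_batch_dim(self.to_k_ip(extra_states))
        value_ip = attn.head_to_batch_dim(self.to_v_ip(extra_states))
        out = out + self.scale * torch.bmm(attn.get_attention_scores(query, key_ip, None), value_ip)

        out = attn.batch_to_head_dim(out)
        out = attn.to_out[0](out)   # linear proj
        out = attn.to_out[1](out)   # dropout
        return out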