IP-Adapter source code: https://github.com/tencent-ailab/IP-Adapter
This post uses the SD1.5 IP-Adapter training script tutorial_train.py as its example and explains both the code and the structure diagram.
As described in the previous post, IP-Adapter essentially inserts an additional input-condition branch for the image prompt.
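For the details of the SD1.5 architecture, I strongly recommend the CSDN post "Stable Diffusion1.5网络结构-超详细原创" (Stable Diffusion 1.5 network structure, in great detail). I won't repeat those details here and will go straight to the conclusions.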
We can inspect the weights contained in the pretrained checkpoint ip-adapter_sd15.bin released with the open-source project:
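import torch
ckpt_bin_dict = torch.load("path/to/ip-adapter_sd15.bin", map_location="cpu")
print(type(ckpt_bin_dict))  # a dict holding two sub-dicts: "image_proj" and "ip_adapter"
for sub_dict in ckpt_bin_dict.values():
    print("Dictionary content:")
    print("\n".join(f"Key: {k}, Shape of value: {v.shape}" for k, v in sub_dict.items()))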
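This prints the output below. The checkpoint holds two sub-dicts: the first ("image_proj") is the projection layer (proj and norm) that turns the 1024-dim CLIP image embedding into image-prompt tokens, and the second ("ip_adapter") holds the trained cross-attention (CA) weights for the image prompt, 16 modules in total (16 pairs of to_k_ip and to_v_ip).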
<class 'dict'>
Dictionary content:
Key: proj.weight, Shape of value: torch.Size([3072, 1024])
Key: proj.bias, Shape of value: torch.Size([3072])
Key: norm.weight, Shape of value: torch.Size([768])
Key: norm.bias, Shape of value: torch.Size([768])
Dictionary content:
Key: 1.to_k_ip.weight, Shape of value: torch.Size([320, 768])
Key: 1.to_v_ip.weight, Shape of value: torch.Size([320, 768])
Key: 3.to_k_ip.weight, Shape of value: torch.Size([320, 768])
Key: 3.to_v_ip.weight, Shape of value: torch.Size([320, 768])
Key: 5.to_k_ip.weight, Shape of value: torch.Size([640, 768])
Key: 5.to_v_ip.weight, Shape of value: torch.Size([640, 768])
Key: 7.to_k_ip.weight, Shape of value: torch.Size([640, 768])
Key: 7.to_v_ip.weight, Shape of value: torch.Size([640, 768])
Key: 9.to_k_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 9.to_v_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 11.to_k_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 11.to_v_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 13.to_k_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 13.to_v_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 15.to_k_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 15.to_v_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 17.to_k_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 17.to_v_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 19.to_k_ip.weight, Shape of value: torch.Size([640, 768])
Key: 19.to_v_ip.weight, Shape of value: torch.Size([640, 768])
Key: 21.to_k_ip.weight, Shape of value: torch.Size([640, 768])
Key: 21.to_v_ip.weight, Shape of value: torch.Size([640, 768])
Key: 23.to_k_ip.weight, Shape of value: torch.Size([640, 768])
Key: 23.to_v_ip.weight, Shape of value: torch.Size([640, 768])
Key: 25.to_k_ip.weight, Shape of value: torch.Size([320, 768])
Key: 25.to_v_ip.weight, Shape of value: torch.Size([320, 768])
Key: 27.to_k_ip.weight, Shape of value: torch.Size([320, 768])
Key: 27.to_v_ip.weight, Shape of value: torch.Size([320, 768])
Key: 29.to_k_ip.weight, Shape of value: torch.Size([320, 768])
Key: 29.to_v_ip.weight, Shape of value: torch.Size([320, 768])
Key: 31.to_k_ip.weight, Shape of value: torch.Size([1280, 768])
Key: 31.to_v_ip.weight, Shape of value: torch.Size([1280, 768])
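The odd indices and channel sizes in this listing can be reproduced with a little bookkeeping. The sketch below rests on assumptions inferred from the listing itself, not on anything the repository states: the ip_adapter keys index the list of all 32 attention processors of the SD1.5 UNet, ordered down blocks, then up blocks, then the mid block (which is why index 31 is 1280 wide); each transformer block contributes a self-attention processor (attn1, no IP weights, even index) followed by a cross-attention processor (attn2, odd index); and image_proj maps a 1024-dim CLIP embedding to 4 tokens of width 768 (3072 = 4 × 768). The helper expected_shapes is hypothetical.

```python
# Hypothetical helper: rebuild the expected tensor shapes of ip-adapter_sd15.bin.
# Assumptions (inferred from the printed listing): 4 image tokens, CLIP image
# embedding dim 1024, cross-attention (text) dim 768, and SD1.5 transformer
# blocks ordered down -> up -> mid, each with attn1 (self) then attn2 (cross).

NUM_TOKENS = 4
CLIP_EMBED_DIM = 1024
CROSS_ATTN_DIM = 768

# Channel width of each transformer block, in processor order.
DOWN = [320, 320, 640, 640, 1280, 1280]   # down_blocks 0-2, two blocks each
UP = [1280] * 3 + [640] * 3 + [320] * 3   # up_blocks 1-3, three blocks each
MID = [1280]                              # mid_block, one block


def expected_shapes():
    shapes = {
        "image_proj": {
            "proj.weight": (NUM_TOKENS * CROSS_ATTN_DIM, CLIP_EMBED_DIM),
            "proj.bias": (NUM_TOKENS * CROSS_ATTN_DIM,),
            "norm.weight": (CROSS_ATTN_DIM,),
            "norm.bias": (CROSS_ATTN_DIM,),
        },
        "ip_adapter": {},
    }
    for block_idx, channels in enumerate(DOWN + UP + MID):
        # attn1 (self-attention) sits at 2*i and carries no IP weights;
        # attn2 (cross-attention) sits at 2*i + 1.
        proc_idx = 2 * block_idx + 1
        for name in ("to_k_ip", "to_v_ip"):
            # a bias-free nn.Linear(768, channels) stores its weight as (out, in)
            shapes["ip_adapter"][f"{proc_idx}.{name}.weight"] = (channels, CROSS_ATTN_DIM)
    return shapes


if __name__ == "__main__":
    for group, entries in expected_shapes().items():
        print(group)
        for key, shape in entries.items():
            print(f"  {key}: {shape}")
```

Comparing this table against the printed keys is a quick sanity check that a checkpoint matches the SD1.5 layout before loading it.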
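By comparing the two classes in /path/to/IP-Adapter/ip_adapter/attention_processor.py, we can see that IP-Adapter essentially adds, on top of the original cross-attention (CA), a separate key and value projection for the image prompt, while sharing the original query.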
This matches Eq. (5) in the original paper, "IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models", exactly.
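For reference, the decoupled cross-attention can be written out as below, in notation close to the paper's: Z is the query feature from the UNet, c_t and c_i are the text and image-prompt features, W_q, W_k, W_v are the frozen projections of the original cross-attention, and W'_k, W'_v are the new image-prompt projections (to_k_ip and to_v_ip in the code). Both terms share the same Q.

```latex
Q = Z W_q,\quad K = c_t W_k,\quad V = c_t W_v,\quad K' = c_i W'_k,\quad V' = c_i W'_v

Z^{new} = \mathrm{Softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V
        + \mathrm{Softmax}\!\left(\frac{Q (K')^{\top}}{\sqrt{d}}\right) V'
```

At inference the paper additionally weights the second term by a factor λ, which corresponds to self.scale in IPAttnProcessor2_0.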
The key code in IPAttnProcessor2_0 consists of two parts:
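# Part 1: split the incoming tokens into text tokens and image-prompt tokens.
# The self.num_tokens image-prompt tokens are appended after the text tokens.
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
    encoder_hidden_states[:, :end_pos, :],  # text tokens
    encoder_hidden_states[:, end_pos:, :],  # image-prompt tokens
)
if attn.norm_cross:
    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)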
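To see where ip_hidden_states comes from, here is a NumPy sketch (not the repository's code) of the training-time data flow as I understand it, using the shapes from the checkpoint above: a 1024-dim CLIP image embedding goes through proj (3072 outputs), is reshaped into 4 tokens of width 768, is layer-normalized, and is appended after the 77 CLIP text tokens; the slicing in the first code part above then recovers both groups. The functions project_image_embeds and split_tokens and the random weights are hypothetical stand-ins.

```python
import numpy as np

NUM_TOKENS, CLIP_DIM, CROSS_DIM, TEXT_LEN = 4, 1024, 768, 77


def layer_norm(x, weight, bias, eps=1e-5):
    # LayerNorm over the last axis, like nn.LayerNorm(768)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight + bias


def project_image_embeds(image_embeds, proj_w, proj_b, norm_w, norm_b):
    # (B, 1024) -> (B, 3072) -> (B, 4, 768); proj_w is stored (out, in) like nn.Linear
    tokens = image_embeds @ proj_w.T + proj_b
    tokens = tokens.reshape(-1, NUM_TOKENS, CROSS_DIM)
    return layer_norm(tokens, norm_w, norm_b)


def split_tokens(encoder_hidden_states, num_tokens=NUM_TOKENS):
    # Same slicing as the processor: the last num_tokens tokens are the image prompt
    end_pos = encoder_hidden_states.shape[1] - num_tokens
    return encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]


rng = np.random.default_rng(0)
batch = 2
proj_w = rng.standard_normal((NUM_TOKENS * CROSS_DIM, CLIP_DIM)) * 0.02  # shape of proj.weight
proj_b = np.zeros(NUM_TOKENS * CROSS_DIM)
norm_w, norm_b = np.ones(CROSS_DIM), np.zeros(CROSS_DIM)

image_embeds = rng.standard_normal((batch, CLIP_DIM))
text_states = rng.standard_normal((batch, TEXT_LEN, CROSS_DIM))
ip_tokens = project_image_embeds(image_embeds, proj_w, proj_b, norm_w, norm_b)

# Concatenate along the sequence axis: 77 text tokens + 4 image tokens = 81
encoder_hidden_states = np.concatenate([text_states, ip_tokens], axis=1)
text_part, ip_part = split_tokens(encoder_hidden_states)
print(encoder_hidden_states.shape, text_part.shape, ip_part.shape)
```

Packing both token groups into one tensor lets the UNet's existing forward signature carry the image prompt unchanged; only the attention processor needs to know where the split is.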
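# Part 2: attend over the image-prompt tokens with the new to_k_ip / to_v_ip,
# reusing the same query as the text cross-attention.
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
# (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
    # attention map kept for visualization only; the upstream line
    # `query @ ip_key.transpose(-2, -1).softmax(dim=-1)` applies softmax before the matmul
    self.attn_map = (query @ ip_key.transpose(-2, -1) * attn.scale).softmax(dim=-1)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# add the image branch to the text branch, weighted by self.scale (lambda in the paper)
hidden_states = hidden_states + self.scale * ip_hidden_states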
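A compact way to see what the second part computes is a single-head NumPy version of decoupled cross-attention. This is a simplified sketch, not the processor itself: one head, no output projection, random weights, and the names decoupled_cross_attention and the entries of w are hypothetical. The query is computed once and shared; text and image tokens each get their own key/value projections; the two attention outputs are summed with the image branch weighted by scale.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v):
    # Softmax(Q K^T / sqrt(d)) V
    d = q.shape[-1]
    return softmax(q @ k.swapaxes(-2, -1) / np.sqrt(d)) @ v


def decoupled_cross_attention(hidden, text_tokens, ip_tokens, w, scale=1.0):
    # Weights here are stored (in, out) so that x @ W works directly
    q = hidden @ w["to_q"]                                            # shared query
    k, v = text_tokens @ w["to_k"], text_tokens @ w["to_v"]           # frozen text branch
    ip_k, ip_v = ip_tokens @ w["to_k_ip"], ip_tokens @ w["to_v_ip"]   # new image branch
    return attention(q, k, v) + scale * attention(q, ip_k, ip_v)


rng = np.random.default_rng(0)
batch, n_query, inner_dim, cross_dim = 2, 64, 320, 768
w = {
    "to_q": rng.standard_normal((inner_dim, inner_dim)) * 0.05,
    "to_k": rng.standard_normal((cross_dim, inner_dim)) * 0.05,
    "to_v": rng.standard_normal((cross_dim, inner_dim)) * 0.05,
    "to_k_ip": rng.standard_normal((cross_dim, inner_dim)) * 0.05,
    "to_v_ip": rng.standard_normal((cross_dim, inner_dim)) * 0.05,
}
hidden = rng.standard_normal((batch, n_query, inner_dim))
text_tokens = rng.standard_normal((batch, 77, cross_dim))
ip_tokens = rng.standard_normal((batch, 4, cross_dim))

out = decoupled_cross_attention(hidden, text_tokens, ip_tokens, w, scale=1.0)
out_text_only = decoupled_cross_attention(hidden, text_tokens, ip_tokens, w, scale=0.0)
print(out.shape)  # (2, 64, 320)
```

With scale=0 the result is exactly the plain text cross-attention, which is why the frozen UNet's original behavior is recoverable and why scale works as a knob for how strongly the image prompt is followed.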
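That's all for this post. Using the structure diagrams and related code snippets, it has walked through the core of the IP-Adapter training code; the next post covers the inference code.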