IP-Adapter source code: https://github.com/tencent-ailab/IP-Adapter
This article uses the SD1.5 IP-Adapter training script tutorial_train.py as its example and explains the code alongside the architecture diagram.
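As shown in the figure above, IP-Adapter adds the topmost branch of the diagram (the image-condition input branch).
The paper also describes this; specifically: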
First, a quick look at the model's training inputs, i.e. the dataloader part of main() in /path/IP-Adapter/tutorial_train.py. The code below builds train_dataloader from the MyDataset class.
# dataloader
train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=args.train_batch_size,
    num_workers=args.dataloader_num_workers,
)
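For reference, here is a minimal NumPy sketch of what a collate function of this kind does: each dataset item is a dict, and the collate function stacks the per-sample arrays into the batch keys the training loop reads (`images`, `text_input_ids`, `clip_images`, `drop_image_embeds`). The per-sample key names, the array shapes, the helper `make_sample` and the 5% drop rate `i_drop_rate` are all assumptions for illustration, not taken from the repository; the random flag merely stands in for whatever rule MyDataset uses to decide when the image condition is dropped.

```python
import numpy as np


def make_sample(rng, i_drop_rate=0.05, resolution=64):
    """Build one hypothetical dataset item (shapes are illustrative only)."""
    return {
        # target image in CHW layout, scaled to [-1, 1] as VAE input
        "image": rng.uniform(-1.0, 1.0, size=(3, resolution, resolution)).astype(np.float32),
        # token ids shaped (1, 77), like a CLIP tokenizer output padded to max length
        "text_input_ids": rng.integers(0, 49408, size=(1, 77)),
        # preprocessed image for the CLIP image encoder, shape (1, 3, 224, 224)
        "clip_image": rng.standard_normal((1, 3, 224, 224)).astype(np.float32),
        # 1 means "drop the image condition for this sample" (assumed random rule)
        "drop_image_embed": int(rng.random() < i_drop_rate),
    }


def collate_fn(data):
    """Stack a list of per-sample dicts into one batch dict."""
    return {
        "images": np.stack([d["image"] for d in data]),  # (B, 3, H, W)
        "text_input_ids": np.concatenate([d["text_input_ids"] for d in data], axis=0),  # (B, 77)
        "clip_images": np.concatenate([d["clip_image"] for d in data], axis=0),  # (B, 3, 224, 224)
        "drop_image_embeds": [d["drop_image_embed"] for d in data],  # list of 0/1 flags
    }


rng = np.random.default_rng(0)
batch = collate_fn([make_sample(rng) for _ in range(4)])
print({k: np.shape(v) for k, v in batch.items()})
```

Keeping the drop flags as a plain list mirrors how the training loop later iterates over `batch["drop_image_embeds"]` with `zip`.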
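The data actually used in training is taken from train_dataloader, one batch per step:
- batch["images"] is passed through vae.encode to obtain latents; torch.randn_like(latents) then generates a random noise tensor noise with the same shape as latents.
- batch["clip_images"] is passed through image_encoder to obtain image_embeds, the image features.
- batch["text_input_ids"] is passed through text_encoder to obtain the text features encoder_hidden_states.
for step, batch in enumerate(train_dataloader):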
    load_data_time = time.perf_counter() - begin  # `begin` is set before the loop, outside this excerpt
    with torch.no_grad():
        latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
    timesteps = timesteps.long()
    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    with torch.no_grad():
        image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds
    # Zero the image embedding of flagged samples so the model also learns the image-unconditional case
    image_embeds_ = []
    for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
        if drop_image_embed == 1:
            image_embeds_.append(torch.zeros_like(image_embed))
        else:
            image_embeds_.append(image_embed)
    image_embeds = torch.stack(image_embeds_)
    with torch.no_grad():
        # [0] is last_hidden_state (per-token text embeddings), not the pooled output
        encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
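The `noise_scheduler.add_noise(latents, noise, timesteps)` call applies the DDPM forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with one timestep per sample. A minimal NumPy sketch of that step, assuming the "scaled_linear" beta schedule (beta_start = 0.00085, beta_end = 0.012, 1000 training steps) used by SD1.5's scheduler config; the function names are hypothetical, and the latent shape (B, 4, 64, 64) corresponds to 512×512 inputs.

```python
import numpy as np


def scaled_linear_alphas_cumprod(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    """alpha_bar_t for a 'scaled_linear' beta schedule (assumed SD1.5 values)."""
    betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps) ** 2
    return np.cumprod(1.0 - betas)


def add_noise(latents, noise, timesteps, alphas_cumprod):
    """Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise, per sample."""
    a_bar = alphas_cumprod[timesteps]
    # reshape (B,) -> (B, 1, 1, 1) so the per-sample coefficients broadcast over C, H, W
    shape = (-1,) + (1,) * (latents.ndim - 1)
    sqrt_a = np.sqrt(a_bar).reshape(shape)
    sqrt_one_minus_a = np.sqrt(1.0 - a_bar).reshape(shape)
    return sqrt_a * latents + sqrt_one_minus_a * noise


rng = np.random.default_rng(0)
alphas_cumprod = scaled_linear_alphas_cumprod()
latents = rng.standard_normal((2, 4, 64, 64))  # stand-in for scaled VAE latents
noise = rng.standard_normal(latents.shape)  # same role as torch.randn_like(latents)
timesteps = rng.integers(0, len(alphas_cumprod), size=latents.shape[0])  # one t per sample
noisy_latents = add_noise(latents, noise, timesteps, alphas_cumprod)
```

Because alpha_bar_t + (1 - alpha_bar_t) = 1, unit-variance latents and noise give unit-variance noisy latents at every timestep; small t leaves x_t close to the clean latents, and t near 1000 leaves it close to pure noise.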
The image projection is again illustrated with the SD1.5 + IP-Adapter training code.
The code below, from main() in /path/IP-Adapter/tutorial_train.py, instantiates the ImageProjModel class:
#ip-adapter
image_proj_model = ImageProjModel(
    cross_attention_dim=unet.config.cross_attention_dim,
    clip_embeddings_dim=image_encoder.config.projection_dim,
    clip_extra_context_tokens=4,
)
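To make these constructor arguments concrete, assume SD1.5's UNet (`cross_attention_dim` = 768) and a CLIP image encoder whose `projection_dim` is 1024 (for example an OpenCLIP ViT-H/14 encoder; this is an assumption, the excerpt does not pin the encoder down). The projection layer then maps one 1024-d image embedding to 4 × 768 = 3072 values. The hypothetical helper below only does that bookkeeping in plain Python.

```python
def image_proj_sizes(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens):
    """Shapes and parameter counts of the Linear + LayerNorm in an ImageProjModel-style module."""
    out_features = clip_extra_context_tokens * cross_attention_dim
    linear_params = clip_embeddings_dim * out_features + out_features  # weight + bias
    layernorm_params = 2 * cross_attention_dim  # elementwise gamma + beta over the last dim
    return {
        "linear_weight_shape": (out_features, clip_embeddings_dim),  # torch.nn.Linear stores (out, in)
        "output_tokens_shape": (clip_extra_context_tokens, cross_attention_dim),
        "trainable_params": linear_params + layernorm_params,
    }


sizes = image_proj_sizes(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
print(sizes)
# -> {'linear_weight_shape': (3072, 1024), 'output_tokens_shape': (4, 768), 'trainable_params': 3150336}
```

Under these assumptions the image-projection module has roughly 3.15M trainable parameters, almost all of them in the single Linear layer.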
The code below is the ImageProjModel class from /path/IP-Adapter/ip_adapter/ip_adapter.py that is instantiated above; its constructor __init__ contains the Linear and LayerNorm layers mentioned earlier.
import torch
class ImageProjModel(torch.nn.Module):
    """Projection Model"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # one Linear maps a CLIP image embedding to N * cross_attention_dim values
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        # LayerNorm over the last dimension of each of the N tokens
        self.norm = torch.nn.LayerNorm(cross_attention_dim)
    def forward(self, image_embeds):
        embeds = image_embeds  # (B, clip_embeddings_dim)
        # (B, N * cross_attention_dim) -> (B, N, cross_attention_dim)
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
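The forward pass is just Linear, reshape, then LayerNorm. The NumPy sketch below mirrors it so the shapes can be checked without PyTorch; the random weights are placeholders rather than trained values, the dimensions (1024 → 4 × 768) are the same assumptions as for SD1.5 above, and eps = 1e-5 matches `torch.nn.LayerNorm`'s default.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last dimension (biased variance, as LayerNorm uses)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


def image_proj_forward(image_embeds, weight, bias, gamma, beta, num_tokens, cross_attention_dim):
    """(B, clip_dim) -> (B, num_tokens, cross_attention_dim): Linear, reshape, LayerNorm."""
    projected = image_embeds @ weight.T + bias  # (B, num_tokens * cross_attention_dim)
    tokens = projected.reshape(-1, num_tokens, cross_attention_dim)
    return layer_norm(tokens, gamma, beta)


rng = np.random.default_rng(0)
B, clip_dim, N, D = 2, 1024, 4, 768
weight = rng.standard_normal((N * D, clip_dim)) * 0.02  # placeholder for trained weights, (out, in)
bias = np.zeros(N * D)
gamma, beta = np.ones(D), np.zeros(D)  # LayerNorm's default initialisation
image_embeds = rng.standard_normal((B, clip_dim))
ip_tokens = image_proj_forward(image_embeds, weight, bias, gamma, beta, N, D)
print(ip_tokens.shape)  # (2, 4, 768)
```

Each image thus becomes N = 4 tokens of width 768, the same width as the text tokens in encoder_hidden_states, which is what lets the image branch feed a cross-attention layer.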
This article covered the input side of the IP-Adapter training code; the next article covers the core part, the cross-attention for the image input.
Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.