赞
踩
下图为《Scaling Rectified Flow Transformers for High-Resolution Image Synthesis》 (ICML 2024 )中的 SD3 架构图。
下面流程图只对正向提示词进行了梳理,负向提示词的流程并无差异。
本文分析的源代码为 diffusers 包中的 SD3 pipeline (位置在/path/to/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
),文本处理部分主要为 其中 __call__()
函数调用的 self.encode_prompt()
函数,主要涉及了 3 个 text encoder 以及对应的 3 个 tokenizer。
其输入输出如下:
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
输入:
具体而言是在 encode_prompt
函数中,通过两次 _get_clip_prompt_embeds
和 _get_t5_prompt_embeds
来调用 3 个 Text Encoder。
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
clip_model_index=0,
)
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
prompt=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
clip_model_index=1,
)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
t5_prompt_embed = self._get_t5_prompt_embeds(
prompt=prompt_3,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
text_encoder
(CLIP L/141) 或者 text_encoder_2
(OpenCLIP bigG/142)。text_encoder
和 text_encoder_2
采用的类一致,所以二者的区别主要是模型权重以及 config 不同。...
def __init__(...
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer_2: CLIPTokenizer,
...
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
clip_skip: Optional[int] = None,
clip_model_index: int = 0,
):
device = device or self._execution_device
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
tokenizer = clip_tokenizers[clip_model_index]
text_encoder = clip_text_encoders[clip_model_index]
在下载的 SD3 模型权重文件中,/path/to/stable-diffusion-3-medium-diffusers
可以找到 text_encoder
和 text_encoder_2
子目录,对比其中的 config(下图中左边为 text_encoder
,右边为 text_encoder_2
),可以知道二者更具体的不同之处:
text_encoder_2
(OpenCLIP bigG/14) 确实更加 big。prompt
,得到输出为不同的两对 prompt_embed, pooled_prompt_embed
;prompt_2_embed, pooled_prompt_2_embed
。pooled text representation
retains only coarse-grained
information about the text input 3, the network also requires information from the sequence representation
c
t
x
t
c_{txt}
ctxt.prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
clip_model_index=0,
)
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
prompt=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
clip_model_index=1,
)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
...
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
T5EncoderModel 的调用则更简洁一点,输入同样是 prompt,并且只有一个输出。
def __init__(...
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
...
t5_prompt_embed = self._get_t5_prompt_embeds(
prompt=prompt_3,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
# 实际为 clip_prompt_embeds = torch.nn.functional.pad(
# clip_prompt_embeds, (0, 4096-2048)
#),即在后面 2048 个维度上 pad 全 0.
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
)
# 在序列长度的维度(-2)上 cat 到一起,得到 77+256 = 333 的长度
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
强烈安利另外一位博主的文章:
Learning transferable visual models from natural language supervision, 2021. ↩︎
Reproducible scaling laws for contrastive language-image learning. In 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2023. doi: 10.1109/cvpr52729.2023.00276. URL http://dx.doi.org/10.1109/CVPR52729.2 023.00276. ↩︎
Sdxl: Improving latent diffusion models for high-resolution image synthesis, 2023. ↩︎
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。