
Line-by-line walkthrough of the Diffusers library's Stable Diffusion pipeline code (spoon-fed edition)




Code location: diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py at 0d927c75427d12210de2ff3eb6fff1c4ebd2e130 · huggingface/diffusers (github.com)

Lines 60-71

In the SD implementation, the noise prediction goes through a second adjustment step. See the paper https://arxiv.org/pdf/2305.08891.pdf

    def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
        """
        Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
        Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
        """
        """
        Purpose: readjust the noise_pred produced by classifier-free guidance (cfg).
        Args:
            noise_pred_text: the second half of the UNet's noise_pred, split along dim 0,
                i.e. noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_cfg: the UNet's noise_pred after the two halves are combined with guidance_scale,
                i.e. noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
            guidance_rescale: weight given to the rescaled noise when mixing it with the original noise_cfg
                (0.0 leaves noise_cfg unchanged)
        """
        # std of noise_pred_text over every dimension except dim 0
        std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
        # std of noise_cfg over every dimension except dim 0
        std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
        # rescale the results from guidance (fixes overexposure)
        # i.e. scale noise_cfg by std_text / std_cfg
        noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
        # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
        # i.e. blend the rescaled noise with the original noise_cfg
        noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
        return noise_cfg
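
A minimal sketch of what the rescaling does numerically. It ports rescale_noise_cfg to NumPy, an assumption made so the example runs without torch; ddof=1 mirrors torch's default unbiased std. The toy tensors and the guidance_scale of 7.5 are illustrative values, not taken from the pipeline. With guidance_rescale=1.0, the per-sample std of the output matches that of noise_pred_text. With 0.0, noise_cfg comes back unchanged.

```python
import numpy as np


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Reduce over every axis except the batch axis (axis 0), keeping dims for broadcasting
    axes = tuple(range(1, noise_pred_text.ndim))
    # ddof=1 mirrors torch.Tensor.std's default (unbiased) estimator
    std_text = noise_pred_text.std(axis=axes, keepdims=True, ddof=1)
    std_cfg = noise_cfg.std(axis=axes, keepdims=True, ddof=1)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg


rng = np.random.default_rng(0)
uncond = rng.standard_normal((2, 4, 8, 8))              # toy unconditional prediction
text = uncond + 0.5 * rng.standard_normal((2, 4, 8, 8))  # toy text-conditioned prediction
guidance_scale = 7.5
cfg = uncond + guidance_scale * (text - uncond)           # standard cfg combination

out = rescale_noise_cfg(cfg, text, guidance_rescale=1.0)
print(cfg.std(axis=(1, 2, 3), ddof=1))   # inflated by guidance
print(out.std(axis=(1, 2, 3), ddof=1))   # equal to text's per-sample std
print(text.std(axis=(1, 2, 3), ddof=1))
```

Because the correction is a single positive factor per sample, it only changes the overall magnitude of each sample's noise. The direction that guidance picked is left intact.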

Lines 74-115

Retrieving the timesteps and the number of inference steps

    def retrieve_timesteps(
        scheduler,  # scheduler that produces the timesteps
        num_inference_steps: Optional[int] = None,  # number of inference steps; if given, timesteps must be None
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,  # custom timesteps; if given, num_inference_steps must be None; if None, the scheduler's default spacing is used
        **kwargs,
    ):
        """
        Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
        custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
        Args:
            scheduler (`SchedulerMixin`):
                The scheduler to get timesteps from.
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
                must be `None`.
        Returns:
            `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
            second element is the number of inference steps.
        """
        # With custom timesteps, derive the number of inference steps from them
        if timesteps is not None:
            # First check whether the scheduler accepts custom timesteps, i.e. whether set_timesteps has a `timesteps` parameter
            # inspect.signature returns the function's signature; its .parameters maps parameter names to Parameter objects
            # so inspect.signature(scheduler.set_timesteps).parameters.keys() gives the parameter names of set_timesteps
            accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
            # If not, raise an error
            if not accepts_timesteps:
                raise ValueError(
                    f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                    f" timestep schedules. Please check whether you are using the correct scheduler."
                )
            # Otherwise configure the scheduler with the custom timesteps, then derive the number of inference steps
            scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
            timesteps = scheduler.timesteps
            num_inference_steps = len(timesteps)
        # Without custom timesteps, let the scheduler build its default timesteps from the number of inference steps
        else:
            scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
            timesteps = scheduler.timesteps
        return timesteps, num_inference_steps
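
The signature check above can be tried in isolation. The sketch below uses two hypothetical scheduler classes, FixedSpacingScheduler and CustomSpacingScheduler, which are not diffusers classes. It uses plain Python lists instead of tensors and reproduces only the control flow of retrieve_timesteps.

```python
import inspect


class FixedSpacingScheduler:
    """Hypothetical scheduler whose set_timesteps only takes a step count."""

    def set_timesteps(self, num_inference_steps, device=None):
        step = 1000 // num_inference_steps
        # Evenly spaced, descending timesteps, e.g. 4 steps -> [999, 749, 499, 249]
        self.timesteps = [999 - i * step for i in range(num_inference_steps)]


class CustomSpacingScheduler(FixedSpacingScheduler):
    """Hypothetical scheduler whose set_timesteps also accepts an explicit list."""

    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None):
        if timesteps is not None:
            self.timesteps = list(timesteps)
        else:
            super().set_timesteps(num_inference_steps, device=device)


def retrieve_timesteps(scheduler, num_inference_steps=None, device=None, timesteps=None, **kwargs):
    if timesteps is not None:
        # A bound method's signature excludes `self`, so .parameters lists only the real arguments
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"{scheduler.__class__.__name__}.set_timesteps does not support custom timestep schedules."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


ts_default, n_default = retrieve_timesteps(FixedSpacingScheduler(), num_inference_steps=4)
ts_custom, n_custom = retrieve_timesteps(CustomSpacingScheduler(), timesteps=[900, 500, 100])
try:
    retrieve_timesteps(FixedSpacingScheduler(), timesteps=[900, 500, 100])
    rejected = False
except ValueError:
    rejected = True

print(ts_default, n_default)  # [999, 749, 499, 249] 4
print(ts_custom, n_custom)    # [900, 500, 100] 3
print(rejected)               # True
```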

Lines 118-250

Definition of the pipeline class and its __init__ function, which has to check the UNet version and the UNet input size

    class StableDiffusionPipeline(
        DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
    ):
        r"""
        Pipeline for text-to-image generation using Stable Diffusion.
        This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
        The pipeline also inherits the following loading methods:
            - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
            - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
            - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
            - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
            - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
        Args:
            vae ([`AutoencoderKL`]):
                Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
            text_encoder ([`~transformers.CLIPTextModel`]):
                Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
            tokenizer ([`~transformers.CLIPTokenizer`]):
                A `CLIPTokenizer` to tokenize text.
            unet ([`UNet2DConditionModel`]):
                A `UNet2DConditionModel` to denoise the encoded image latents.
            scheduler ([`SchedulerMixin`]):
                A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
                [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
            safety_checker ([`StableDiffusionSafetyChecker`]):
                Classification module that estimates whether generated images could be considered offensive or harmful.
                Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
                about a model's potential harms.
            feature_extractor ([`~transformers.CLIPImageProcessor`]):
                A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
        """
        model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
        _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
        _exclude_from_cpu_offload = ["safety_checker"]
        _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
        def __init__(
            self,
            vae: AutoencoderKL,  # encodes images into latents and decodes latents back into images
            text_encoder: CLIPTextModel,  # turns token ids into text embeddings
            tokenizer: CLIPTokenizer,  # splits the text prompt into token ids
            unet: UNet2DConditionModel,  # denoises the latents
            scheduler: KarrasDiffusionSchedulers,  # scheduler
            safety_checker: StableDiffusionSafetyChecker,  # safety check
            feature_extractor: CLIPImageProcessor,  # extracts features from the output images, which are then fed to safety_checker
            image_encoder: CLIPVisionModelWithProjection = None,  # embeds the reference image (IP-Adapter)
            requires_safety_checker: bool = True,
        ):
            super().__init__()
            # If the scheduler config has steps_offset and it is not 1, emit a deprecation warning and set steps_offset to 1
            if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
                deprecation_message = (
                    f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                    f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                    "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                    " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                    " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                    " file"
                )
                deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
                new_config = dict(scheduler.config)
                new_config["steps_offset"] = 1
                scheduler._internal_dict = FrozenDict(new_config)
            # If the scheduler config has clip_sample set to True, emit a deprecation warning and set it to False
            if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
                deprecation_message = (
                    f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                    " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                    " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                    " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                    " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
                )
                deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
                new_config = dict(scheduler.config)
                new_config["clip_sample"] = False
                scheduler._internal_dict = FrozenDict(new_config)
            # If a safety check is required but no safety_checker was passed, log a warning
            if safety_checker is None and requires_safety_checker:
                logger.warning(
                    f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                    " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                    " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                    " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                    " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                    " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
                )
            # If there is a safety_checker but no feature_extractor, raise an error,
            # because the safety checker's input is produced by the feature_extractor
            if safety_checker is not None and feature_extractor is None:
                raise ValueError(
                    f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                    " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
                )
            # Check whether the UNet config was saved by a diffusers version older than 0.9.0,
            # using the _diffusers_version field of the UNet config (the UNet here reports 0.6.0)
            is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
                version.parse(unet.config._diffusers_version).base_version
            ) < version.parse("0.9.0.dev0")
            # Check whether the UNet input size is smaller than 64
            is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
            # If the version is older than 0.9.0 and the input size is smaller than 64, warn and set sample_size to 64
            if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
                deprecation_message = (
                    "The configuration file of the unet has set the default `sample_size` to smaller than"
                    " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                    " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                    " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                    " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                    " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                    " in the config might lead to incorrect results in future versions. If you have downloaded this"
                    " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                    " the `unet/config.json` file"
                )
                deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
                new_config = dict(unet.config)
                new_config["sample_size"] = 64
                unet._internal_dict = FrozenDict(new_config)
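
A minimal sketch of the UNet version and sample_size check. It uses plain dicts as a stand-in for unet.config, an assumption: the real code reads attributes with hasattr. The version comparisons go through the packaging library's version module, which is presumably what `version.parse` refers to in the original. base_version strips dev/pre-release suffixes, so "0.9.0.dev0" is treated as 0.9.0 and is not flagged as old.

```python
from packaging import version


def is_version_less_0_9_0(diffusers_version: str) -> bool:
    # base_version drops dev/pre/post/local segments, e.g. "0.9.0.dev0" -> "0.9.0"
    base = version.parse(diffusers_version).base_version
    return version.parse(base) < version.parse("0.9.0.dev0")


def needs_sample_size_patch(config: dict) -> bool:
    # Both conditions from __init__: saved by diffusers < 0.9.0 AND sample_size < 64
    is_old = "_diffusers_version" in config and is_version_less_0_9_0(config["_diffusers_version"])
    is_small = "sample_size" in config and config["sample_size"] < 64
    return is_old and is_small


for v in ["0.6.0", "0.8.1", "0.9.0.dev0", "0.9.0", "0.21.4"]:
    print(v, is_version_less_0_9_0(v))
# 0.6.0 True / 0.8.1 True / 0.9.0.dev0 False / 0.9.0 False / 0.21.4 False

print(needs_sample_size_patch({"_diffusers_version": "0.6.0", "sample_size": 32}))  # True
print(needs_sample_size_patch({"_diffusers_version": "0.6.0", "sample_size": 64}))  # False
```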

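Lines 251-279

Sliced and tiled VAE decoding: if the batch or the image is too large, the input is passed to the VAE in slices (one image at a time) or in tiles (patches of an image).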

    def enable_vae_slicing(self):
        # Enable sliced VAE decoding: the images in the batch are decoded one at a time.
        # Reduces VRAM usage for large batches, at the cost of speed
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()
    def disable_vae_slicing(self):
        # Disable sliced VAE decoding
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()
    def enable_vae_tiling(self):
        # Enable tiled VAE decoding: the image is split into overlapping tiles, each tile is decoded,
        # and the outputs are blended into the final image. Useful for processing large images
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()
    def disable_vae_tiling(self):
        # Disable tiled VAE decoding
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()
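
Slicing works because decoding treats each sample independently. Decoding one sample at a time and concatenating the results therefore gives the same output as decoding the whole batch at once, while each call only holds one sample. The sketch below checks this with a hypothetical fake_decode in place of the real vae.decode. fake_decode keeps 3 channels and upsamples each spatial dimension 8x. This illustrates the idea only and is not diffusers' implementation. Tiling is not shown because it also needs overlap blending.

```python
import numpy as np

seen_batch_sizes = []  # records the batch size of every decode call


def fake_decode(latents):
    # Hypothetical stand-in for vae.decode: keep 3 of the 4 latent channels and
    # upsample each spatial dimension 8x with nearest-neighbor repetition
    seen_batch_sizes.append(latents.shape[0])
    rgb = latents[:, :3]
    return rgb.repeat(8, axis=2).repeat(8, axis=3)


def sliced_decode(latents, decode):
    # Decode one sample at a time and concatenate along the batch axis
    return np.concatenate([decode(latents[i : i + 1]) for i in range(latents.shape[0])], axis=0)


latents = np.random.default_rng(0).standard_normal((4, 4, 8, 8))
full = fake_decode(latents)
sliced = sliced_decode(latents, fake_decode)
print(full.shape)                    # (4, 3, 64, 64)
print(np.array_equal(full, sliced))  # True
print(seen_batch_sizes)              # [4, 1, 1, 1, 1]
```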

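Lines 280-491

Encoding the text prompt. The older version of this function is deprecated and no longer used; the new version is covered here.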

    def encode_prompt(
        self,
        prompt,  # text prompt
        device,
        num_images_per_prompt,  # number of images to generate per prompt
        do_classifier_free_guidance,  # whether to use classifier-free guidance (cfg)
        negative_prompt=None,  # negative prompt
        prompt_embeds: Optional[torch.FloatTensor] = None,  # precomputed text embeddings
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,  # precomputed negative text embeddings
        lora_scale: Optional[float] = None,  # LoRA scale applied to the text encoder's LoRA layers
        clip_skip: Optional[int] = None,  # number of final CLIP layers to skip
    ):
        r"""
        Encodes the prompt into text encoder hidden states.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        # i.e. set the LoRA scale used by text_encoder
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale
            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
        # Set batch_size from the number of prompts
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        # If no precomputed prompt_embeds were given
        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            # If the pipeline is an instance of TextualInversionLoaderMixin,
            # maybe_convert_prompt expands multi-vector textual inversion tokens in the prompt
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
            # Tokenize the text; the output is dict-like, with input_ids and attention_mask.
            # input_ids are the token ids; attention_mask marks which positions are real text and which are padding
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            # The token ids, i.e. input_ids
            text_input_ids = text_inputs.input_ids
            # Tokenize again with padding="longest" and no truncation to get the untruncated token ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
            # If the untruncated ids are at least as long as the truncated ones and differ from them,
            # the prompt was truncated: recover the removed text
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                # Decode the truncated-away token ids back into text
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
            # If the text encoder config has use_attention_mask set to True, pass the attention mask
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None
            # Without clip_skip, use the text encoder's final output directly
            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            # With clip_skip
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                # The output keys are ['last_hidden_state', 'pooler_output', 'hidden_states']:
                # [-1] selects hidden_states, and [-(clip_skip + 1)] selects the (clip_skip + 1)-th layer from the end
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                # i.e. apply the final LayerNorm so the distribution matches that of last_hidden_state
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
        # Take the dtype from whichever model is available, so the text embeddings use a consistent dtype
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype
        # Cast to that dtype and move to device
        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
        # batch size and sequence length
        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        # Repeat num_images_per_prompt times along the sequence dimension
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # [bs_embed, num_images_per_prompt*seq_len, channel] -> [bs_embed*num_images_per_prompt, seq_len, channel]
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        # get unconditional embeddings for classifier free guidance
        # Build the unconditional tokens when:
        # 1. cfg is on and no precomputed negative embeddings were given;
        # 2. if negative_prompt is None, empty strings are used; otherwise negative_prompt is used
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            # prompt and negative_prompt must be of the same type
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            # Set the unconditional tokens according to the type of negative_prompt
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt
            # textual inversion: process multi-vector tokens if necessary
            # i.e. expand multi-vector textual inversion tokens in the negative prompts
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
            # Tokenize the unconditional prompts, padded to the same length as prompt_embeds
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            # attention_mask for the unconditional tokens
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None
            # Use the embeddings of the unconditional tokens as the negative embeddings
            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]
        # With cfg, duplicate the negative embeddings the same way as the positive ones
        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)
        return prompt_embeds, negative_prompt_embeds
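
Two index manipulations in encode_prompt are easy to misread: the repeat + view duplication, and the clip_skip layer index. The NumPy sketch below covers both. np.tile stands in for torch.Tensor.repeat, an assumption made to avoid a torch dependency. The sketch shows that the duplication places each prompt's copies next to each other. It also shows which hidden state clip_skip=1 selects for a 12-layer text encoder such as the one in clip-vit-large-patch14. For that encoder, hidden_states holds 13 entries: the embedding output plus one per layer. The helper names are hypothetical.

```python
import numpy as np


def duplicate_embeds(prompt_embeds, num_images_per_prompt):
    # Same two steps as encode_prompt; np.tile plays the role of torch.Tensor.repeat
    bs_embed, seq_len, _ = prompt_embeds.shape
    tiled = np.tile(prompt_embeds, (1, num_images_per_prompt, 1))       # [bs, n * seq_len, dim]
    return tiled.reshape(bs_embed * num_images_per_prompt, seq_len, -1)  # [bs * n, seq_len, dim]


def clip_skip_layer_index(num_hidden_states, clip_skip):
    # hidden_states[-(clip_skip + 1)] rewritten as a non-negative index
    return num_hidden_states - (clip_skip + 1)


embeds = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # 2 prompts, seq_len 3, dim 4
dup = duplicate_embeds(embeds, num_images_per_prompt=2)
print(dup.shape)                                          # (4, 3, 4)
print(np.array_equal(dup, np.repeat(embeds, 2, axis=0)))  # True: order is p0, p0, p1, p1
print(clip_skip_layer_index(13, clip_skip=1))             # 11, the penultimate layer's output
```

The "mps friendly" comment in the source suggests this repeat-then-view form was chosen over other equivalent reshuffles for Apple MPS compatibility. The resulting order is the same as repeat_interleave along dim 0.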

Lines 493-626

Image encoding function

    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        """
        Encode the reference image; returns image_embeds and uncond_image_embeds.
        Args:
            image: reference image (ip_adapter_image)
            num_images_per_prompt: number of images to generate per prompt
            output_hidden_states: whether to return hidden states, specifically the penultimate layer (hidden_states[-2]);
                the caller sets it to False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
        """
        # Get the image encoder's dtype so the image data can be cast to match
        dtype = next(self.image_encoder.parameters()).dtype
        # If image is not a Tensor, convert it with feature_extractor (CLIPImageProcessor)
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values
        # Move image to device and cast its dtype
        image = image.to(device=device, dtype=dtype)
        # Return hidden states
        if output_hidden_states:
            # Hidden states of the penultimate layer
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            # Repeat num_images_per_prompt times along dim 0, one copy per generated image
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            # Feed an all-zeros image through the encoder as the unconditional input
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        # Do not return hidden states
        else:
            # Use image_embeds directly
            image_embeds = self.image_encoder(image).image_embeds
            # Repeat num_images_per_prompt times along dim 0
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Use an all-zeros tensor directly as the unconditional embedding
            uncond_image_embeds = torch.zeros_like(image_embeds)
            return image_embeds, uncond_image_embeds
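
repeat_interleave along dim 0 keeps each image's copies adjacent. This is the same ordering that the repeat + view in encode_prompt produces for the text embeddings, so image and text embeddings line up sample by sample. The small NumPy comparison below uses toy values. np.repeat mimics torch.repeat_interleave and np.tile mimics torch.Tensor.repeat; this mapping is an assumption made to keep the example torch-free.

```python
import numpy as np

image_embeds = np.array([[1.0, 2.0], [3.0, 4.0]])  # 2 reference images, embedding dim 2 (toy values)
num_images_per_prompt = 2

interleaved = np.repeat(image_embeds, num_images_per_prompt, axis=0)  # like torch.repeat_interleave(n, dim=0)
tiled = np.tile(image_embeds, (num_images_per_prompt, 1))             # like torch.Tensor.repeat(n, 1)
uncond = np.zeros_like(interleaved)  # the non-hidden-state branch uses all zeros as the unconditional embeds

print(interleaved.tolist())  # [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]
print(tiled.tolist())        # [[1.0, 2.0], [3.0, 4.0], [1.0, 2.0], [3.0, 4.0]]
print(uncond.shape)          # (4, 2)
```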

Running the safety check on the final output images; returns image and has_nsfw_concept.

    def run_safety_checker(self, image, device, dtype):
        # Without a safety checker, has_nsfw_concept is None
        if self.safety_checker is None:
            has_nsfw_concept = None
        # With a safety checker: convert image to PIL, extract CLIP features,
        # then feed both the image and the features to the safety checker
        else:
            # If image is a Tensor, post-process it into PIL images
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            # Otherwise, convert it to PIL with numpy_to_pil
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            # Extract features with feature_extractor (CLIPImageProcessor)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            # Run the safety check
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

Preparing the extra keyword arguments for the scheduler step: eta and generator.

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        # Check whether the scheduler's step method accepts an 'eta' argument (only DDIM uses it)
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        # check if the scheduler accepts generator
        # i.e. whether the scheduler's step method accepts a 'generator' argument
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
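
Filtering by signature lets one call site, scheduler.step(..., **extra_step_kwargs), work for schedulers whose step methods take different arguments. Both scheduler classes below are hypothetical. Their step methods only mimic a DDIM-style signature (with eta and generator) and a signature with a generator but no eta.

```python
import inspect


class DDIMLikeScheduler:
    # Hypothetical: step accepts both eta and generator, like a DDIM-style scheduler
    def step(self, model_output, timestep, sample, eta=0.0, generator=None):
        return sample


class EulerLikeScheduler:
    # Hypothetical: step accepts a generator but no eta
    def step(self, model_output, timestep, sample, generator=None):
        return sample


def prepare_extra_step_kwargs(scheduler, generator, eta):
    params = set(inspect.signature(scheduler.step).parameters.keys())
    extra = {}
    if "eta" in params:
        extra["eta"] = eta
    if "generator" in params:
        extra["generator"] = generator
    return extra


extra_ddim = prepare_extra_step_kwargs(DDIMLikeScheduler(), generator=None, eta=0.5)
extra_euler = prepare_extra_step_kwargs(EulerLikeScheduler(), generator=None, eta=0.5)
print(extra_ddim)   # {'eta': 0.5, 'generator': None}
print(extra_euler)  # {'generator': None}

# The same call works for both; passing eta to EulerLikeScheduler.step would raise a TypeError
for sched, extra in ((DDIMLikeScheduler(), extra_ddim), (EulerLikeScheduler(), extra_euler)):
    sched.step(0.0, 10, 1.0, **extra)
```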

Checking that the inputs are valid

    def check_inputs(
        self,
        prompt,  # prompt
        height,  # image height
        width,  # image width
        callback_steps,  # during inference the callback runs once every callback_steps steps; it can inspect, log, or otherwise monitor and control the model's behavior
        negative_prompt=None,  # negative prompt
        prompt_embeds=None,  # precomputed prompt embeddings
        negative_prompt_embeds=None,  # precomputed negative prompt embeddings
        callback_on_step_end_tensor_inputs=None,  # names of the tensors passed to the callback at the end of each step
    ):
        # height and width must be multiples of 8
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        # callback_steps must be a positive integer
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        # Every entry of callback_on_step_end_tensor_inputs must be in self._callback_tensor_inputs
        # _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        # Exactly one of prompt and prompt_embeds must be given, and prompt must be a str or a list
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        # negative_prompt and negative_prompt_embeds cannot both be given
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        # prompt_embeds and negative_prompt_embeds must have the same shape
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

Preparing the latents for Stable Diffusion inference.

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Args:
            generator: torch random number generator (torch.Generator), or a list with one generator per sample
        """
        # Latent shape; self.vae_scale_factor is the VAE's spatial downsampling factor
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        # If generator is a list whose length differs from batch_size, raise an error
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        # If no latents were given, sample Gaussian noise; otherwise use the given latents
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)
        # scale the initial noise by the standard deviation required by the scheduler
        # i.e. multiply the initial noise by scheduler.init_noise_sigma
        latents = latents * self.scheduler.init_noise_sigma
        return latents
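
A NumPy sketch of prepare_latents. numpy.random.Generator stands in for torch.Generator. The sampling branch is a simplified stand-in for diffusers' randn_tensor and assumes that, given a list of generators, each sample is drawn from its own generator. The latent shape is height // vae_scale_factor by width // vae_scale_factor, e.g. 64x64 for a 512x512 image with a factor of 8. The noise is then scaled by init_noise_sigma. The values used for init_noise_sigma below are illustrative; the real value depends on the scheduler.

```python
import numpy as np


def prepare_latents_sketch(batch_size, num_channels_latents, height, width,
                           vae_scale_factor, init_noise_sigma, generator, latents=None):
    # Latent shape: the VAE downsamples each spatial dimension by vae_scale_factor
    shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(f"Got {len(generator)} generators for a batch size of {batch_size}.")
    if latents is None:
        if isinstance(generator, list):
            # One generator per sample, so each sample's noise is reproducible on its own
            latents = np.stack([g.standard_normal(shape[1:]) for g in generator])
        else:
            latents = generator.standard_normal(shape)
    # Scale the initial noise to the standard deviation the scheduler expects
    return latents * init_noise_sigma


lat = prepare_latents_sketch(2, 4, 512, 512, 8, 1.0, np.random.default_rng(0))
print(lat.shape)  # (2, 4, 64, 64)

a = prepare_latents_sketch(2, 4, 64, 64, 8, 1.0, [np.random.default_rng(s) for s in (1, 2)])
b = prepare_latents_sketch(1, 4, 64, 64, 8, 1.0, [np.random.default_rng(2)])
print(np.array_equal(a[1], b[0]))  # True: sample 1 depends only on its own generator

scaled = prepare_latents_sketch(1, 4, 64, 64, 8, 2.0, None, latents=np.ones((1, 4, 8, 8)))
try:
    prepare_latents_sketch(2, 4, 64, 64, 8, 1.0, [np.random.default_rng(0)])
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
```

Passing one generator per sample means a single image from a batch can be regenerated exactly by reusing its seed.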

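Lines 628-776

This code defines an enable_freeu method, which enables the FreeU mechanism of the UNet. FreeU is a technique for improving the quality of the denoising process.

The method takes four parameters: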