
How to Get AI to Do Your Work - Entertainment (3)

Background

Today's topic leans more toward coding techniques, so readers with little or no programming experience may find it a bit challenging.

The previous article covered how to generate relatively stable, wide-view video. That implementation was fairly simple and relied mainly on the UI for generation, and the results were visibly mediocre: character consistency and motion continuity were lacking. The reasons are as follows:

1. The batch image2image mode in the Stable Diffusion WebUI still cannot do true image-to-image control.

2. Control is actually applied through a fixed text prompt plus multiple ControlNets.

If you want the frames to be stable, you get better quality, coherence, and stability by combining image-based control information with a slightly different prompt for each frame (image2image controls the overall style and content, ControlNet controls the details, and the prompt handles per-frame content differences); a rough sketch of this combination follows below.
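As a rough illustration only, not the exact pipeline used later in this series: recent diffusers releases provide StableDiffusionControlNetImg2ImgPipeline, which accepts an init image, a ControlNet condition image, and a prompt at the same time. The frame paths and per-frame prompts below are placeholders.

  import cv2
  import numpy as np
  from PIL import Image
  from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
  from diffusers.utils import load_image

  controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
  pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None
  ).to("cuda")

  frames = [load_image(p) for p in ["frame_000.png", "frame_001.png"]]           # placeholder frame paths
  prompts = ["a boy walking, morning light", "a boy walking, turning his head"]  # per-frame prompt variations

  results = []
  for init_image, prompt in zip(frames, prompts):
      # ControlNet hint: Canny edges of the current frame (controls structural detail)
      edges = cv2.Canny(np.array(init_image), 100, 200)[:, :, None]
      control_image = Image.fromarray(np.concatenate([edges, edges, edges], axis=2))
      out = pipe(prompt=prompt,                 # prompt: per-frame content differences
                 image=init_image,              # image2image: overall style and content
                 control_image=control_image,   # ControlNet: structural detail
                 strength=0.6,                  # how much the init frame may change
                 num_inference_steps=30).images[0]
      results.append(out)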

This article will not go into the specific details of stable image generation. Instead, it covers something more important: how to implement the WebUI's functionality in code, and how to build a more controllable and efficient image-generation pipeline programmatically.

Contents

1. Building a Stable Diffusion + ControlNet pipeline in code

2. Building a multi-ControlNet pipeline

3. Implementing functionality that diffusers doesn't have yet

Building a Stable Diffusion + ControlNet pipeline in code

Install the required packages

  pip install --upgrade diffusers accelerate transformers
  # To install the latest diffusers from source, run the following instead
  pip install git+https://github.com/huggingface/diffusers
  pip install controlnet_hinter==0.0.5
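One extra dependency worth flagging: the scripts below call pipe.enable_xformers_memory_efficient_attention(), which assumes xformers is available. If it is not already in your environment, you will likely also need:

  pip install xformers

(or simply remove that call from the scripts).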

Build the Stable Diffusion + ControlNet script

  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
  from diffusers.utils import load_image
  import torch
  import numpy as np
  from PIL import Image
  import cv2
  import controlnet_hinter  # used below for the pose/depth/segmentation/hough hints

  # Load a test image
  original_image = load_image("https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_imgvar/input_image_vermeer.png")
  original_image  # in a notebook, this displays the image

  # Set up the ControlNet + Stable Diffusion pipeline
  controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", cache_dir='./')
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None
  ).to('cuda')
  pipe.enable_xformers_memory_efficient_attention()

  # A pre-extracted Canny edge map of the image
  canny_edged_image = load_image("https://huggingface.co/takuma104/controlnet_dev/resolve/main/vermeer_canny_edged.png")
  canny_edged_image

  # Generate with the Canny ControlNet
  generator = torch.Generator(device="cpu").manual_seed(3)
  image = pipe(prompt="best quality, extremely detailed",
               negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
               image=canny_edged_image,
               num_inference_steps=30,
               generator=generator).images[0]
  image

  # Extract the Canny edges ourselves and tune the thresholds
  control_image = np.array(original_image)
  low_threshold = 10    # default=100
  high_threshold = 150  # default=200
  control_image = cv2.Canny(control_image, low_threshold, high_threshold)
  control_image = control_image[:, :, None]
  control_image = np.concatenate([control_image, control_image, control_image], axis=2)
  control_image = Image.fromarray(control_image)
  control_image

  # Generate with the Canny ControlNet again, now using our own edge map
  generator = torch.Generator(device="cpu").manual_seed(3)
  image = pipe(prompt="best quality, extremely detailed",
               negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
               image=control_image,
               num_inference_steps=30,
               generator=generator).images[0]
  image

  # OpenPose ControlNet pipeline
  controlnet = None
  pipe = None
  controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", cache_dir='./')
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None
  ).to('cuda')
  pipe.enable_xformers_memory_efficient_attention()

  # A ready-made pose image
  pose_image = load_image('https://huggingface.co/takuma104/controlnet_dev/resolve/main/pose.png')
  pose_image

  # Generate with the pose pipeline
  generator = torch.Generator(device="cpu").manual_seed(0)
  image = pipe(prompt="best quality, extremely detailed, football, a boy",
               negative_prompt="lowres, bad anatomy, worst quality, low quality",
               image=pose_image,
               generator=generator,
               num_inference_steps=30).images[0]
  image

  # Extract the pose control hint with controlnet_hinter
  control_image = controlnet_hinter.hint_openpose(original_image)
  control_image

  # Generate with the pose pipeline, using the extracted hint
  generator = torch.Generator(device="cpu").manual_seed(0)
  image = pipe(prompt="best quality, extremely detailed",
               negative_prompt="lowres, bad anatomy, worst quality, low quality",
               image=control_image,
               generator=generator,
               num_inference_steps=30).images[0]
  image

  # Depth map
  controlnet = None
  pipe = None
  controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth")
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None
  ).to('cuda')
  pipe.enable_xformers_memory_efficient_attention()
  control_image = controlnet_hinter.hint_depth(original_image)
  control_image

  generator = torch.Generator(device="cpu").manual_seed(0)
  image = pipe(prompt="best quality, extremely detailed",
               negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
               image=control_image,
               generator=generator,
               num_inference_steps=30).images[0]
  image

  # Scribble
  controlnet = None
  pipe = None
  controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-scribble")
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None).to('cuda')
  pipe.enable_xformers_memory_efficient_attention()
  scribble_image = load_image('https://github.com/lllyasviel/ControlNet/raw/main/test_imgs/user_1.png')
  scribble_image

  generator = torch.Generator(device="cpu").manual_seed(1)
  image = pipe(prompt="a turtle, best quality, extremely detailed",
               negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
               image=scribble_image,
               generator=generator,
               num_inference_steps=30).images[0]
  image

  # Segmentation
  controlnet = None
  pipe = None
  controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg")
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None).to('cuda')
  pipe.enable_xformers_memory_efficient_attention()
  control_image = controlnet_hinter.hint_segmentation(original_image)
  control_image

  generator = torch.Generator(device="cpu").manual_seed(0)
  image = pipe(prompt="best quality, extremely detailed",
               negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
               image=control_image,
               generator=generator,
               num_inference_steps=30).images[0]
  image

  # Hough lines (MLSD): mostly used for architecture and large scenes
  controlnet = None
  pipe = None
  controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-mlsd")
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None).to('cuda')
  pipe.enable_xformers_memory_efficient_attention()
  control_image = controlnet_hinter.hint_hough(original_image)
  control_image

  generator = torch.Generator(device="cpu").manual_seed(2)
  image = pipe(prompt="best quality, extremely detailed",
               negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
               image=control_image,
               generator=generator,
               num_inference_steps=30).images[0]
  image
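One optional tweak that applies to any of the pipelines above (it also appears in the docstring of the multi-ControlNet example later): swapping the default scheduler for UniPCMultistepScheduler usually gives comparable results in noticeably fewer steps.

  from diffusers import UniPCMultistepScheduler

  # faster sampler: around 20 steps is often enough instead of 30-50
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)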
Building a multi-ControlNet pipeline

At the time of writing, the diffusers package does not yet implement image generation controlled by multiple ControlNets at once; if you need multi-ControlNet, you can use the following code.

  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. from typing import Any, Callable, Dict, List, Optional, Union, Tuple
  16. import numpy as np
  17. import PIL.Image
  18. import torch
  19. from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
  20. from diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel
  21. from diffusers.schedulers import KarrasDiffusionSchedulers
  22. from diffusers.utils import (
  23. PIL_INTERPOLATION,
  24. is_accelerate_available,
  25. is_accelerate_version,
  26. logging,
  27. randn_tensor,
  28. replace_example_docstring,
  29. )
  30. from diffusers.pipeline_utils import DiffusionPipeline
  31. from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
  32. from diffusers.models.controlnet import ControlNetOutput
  33. logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  34. class ControlNetProcessor(object):
  35. def __init__(
  36. self,
  37. controlnet: ControlNetModel,
  38. image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
  39. conditioning_scale: float = 1.0,
  40. ):
  41. self.controlnet = controlnet
  42. self.image = image
  43. self.conditioning_scale = conditioning_scale
  44. def _default_height_width(self, height, width, image):
  45. if isinstance(image, list):
  46. image = image[0]
  47. if height is None:
  48. if isinstance(image, PIL.Image.Image):
  49. height = image.height
  50. elif isinstance(image, torch.Tensor):
  51. height = image.shape[3]
  52. height = (height // 8) * 8 # round down to nearest multiple of 8
  53. if width is None:
  54. if isinstance(image, PIL.Image.Image):
  55. width = image.width
  56. elif isinstance(image, torch.Tensor):
  57. width = image.shape[2]
  58. width = (width // 8) * 8 # round down to nearest multiple of 8
  59. return height, width
  60. def default_height_width(self, height, width):
  61. return self._default_height_width(height, width, self.image)
  62. def _prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype):
  63. if not isinstance(image, torch.Tensor):
  64. if isinstance(image, PIL.Image.Image):
  65. image = [image]
  66. if isinstance(image[0], PIL.Image.Image):
  67. image = [
  68. np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image
  69. ]
  70. image = np.concatenate(image, axis=0)
  71. image = np.array(image).astype(np.float32) / 255.0
  72. image = image.transpose(0, 3, 1, 2)
  73. image = torch.from_numpy(image)
  74. elif isinstance(image[0], torch.Tensor):
  75. image = torch.cat(image, dim=0)
  76. image_batch_size = image.shape[0]
  77. if image_batch_size == 1:
  78. repeat_by = batch_size
  79. else:
  80. # image batch size is the same as prompt batch size
  81. repeat_by = num_images_per_prompt
  82. image = image.repeat_interleave(repeat_by, dim=0)
  83. image = image.to(device=device, dtype=dtype)
  84. return image
  85. def _check_inputs(self, image, prompt, prompt_embeds):
  86. image_is_pil = isinstance(image, PIL.Image.Image)
  87. image_is_tensor = isinstance(image, torch.Tensor)
  88. image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
  89. image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
  90. if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
  91. raise TypeError(
  92. "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
  93. )
  94. if image_is_pil:
  95. image_batch_size = 1
  96. elif image_is_tensor:
  97. image_batch_size = image.shape[0]
  98. elif image_is_pil_list:
  99. image_batch_size = len(image)
  100. elif image_is_tensor_list:
  101. image_batch_size = len(image)
  102. if prompt is not None and isinstance(prompt, str):
  103. prompt_batch_size = 1
  104. elif prompt is not None and isinstance(prompt, list):
  105. prompt_batch_size = len(prompt)
  106. elif prompt_embeds is not None:
  107. prompt_batch_size = prompt_embeds.shape[0]
  108. if image_batch_size != 1 and image_batch_size != prompt_batch_size:
  109. raise ValueError(
  110. f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
  111. )
  112. def check_inputs(self, prompt, prompt_embeds):
  113. self._check_inputs(self.image, prompt, prompt_embeds)
  114. def prepare_image(self, width, height, batch_size, num_images_per_prompt, device, do_classifier_free_guidance):
  115. self.image = self._prepare_image(
  116. self.image, width, height, batch_size, num_images_per_prompt, device, self.controlnet.dtype
  117. )
  118. if do_classifier_free_guidance:
  119. self.image = torch.cat([self.image] * 2)
  120. def __call__(
  121. self,
  122. sample: torch.FloatTensor,
  123. timestep: Union[torch.Tensor, float, int],
  124. encoder_hidden_states: torch.Tensor,
  125. class_labels: Optional[torch.Tensor] = None,
  126. timestep_cond: Optional[torch.Tensor] = None,
  127. attention_mask: Optional[torch.Tensor] = None,
  128. cross_attention_kwargs: Optional[Dict[str, Any]] = None,
  129. return_dict: bool = True,
  130. ) -> Tuple:
  131. down_block_res_samples, mid_block_res_sample = self.controlnet(
  132. sample=sample,
  133. controlnet_cond=self.image,
  134. timestep=timestep,
  135. encoder_hidden_states=encoder_hidden_states,
  136. class_labels=class_labels,
  137. timestep_cond=timestep_cond,
  138. attention_mask=attention_mask,
  139. cross_attention_kwargs=cross_attention_kwargs,
  140. return_dict=False,
  141. )
  142. down_block_res_samples = [
  143. down_block_res_sample * self.conditioning_scale for down_block_res_sample in down_block_res_samples
  144. ]
  145. mid_block_res_sample *= self.conditioning_scale
  146. return (down_block_res_samples, mid_block_res_sample)
  147. EXAMPLE_DOC_STRING = """
  148. Examples:
  149. ```py
  150. >>> # !pip install opencv-python transformers accelerate
  151. >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
  152. >>> from diffusers.utils import load_image
  153. >>> import numpy as np
  154. >>> import torch
  155. >>> import cv2
  156. >>> from PIL import Image
  157. >>> # download an image
  158. >>> image = load_image(
  159. ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
  160. ... )
  161. >>> image = np.array(image)
  162. >>> # get canny image
  163. >>> image = cv2.Canny(image, 100, 200)
  164. >>> image = image[:, :, None]
  165. >>> image = np.concatenate([image, image, image], axis=2)
  166. >>> canny_image = Image.fromarray(image)
  167. >>> # load control net and stable diffusion v1-5
  168. >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
  169. >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
  170. ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
  171. ... )
  172. >>> # speed up diffusion process with faster scheduler and memory optimization
  173. >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
  174. >>> # remove following line if xformers is not installed
  175. >>> pipe.enable_xformers_memory_efficient_attention()
  176. >>> pipe.enable_model_cpu_offload()
  177. >>> # generate image
  178. >>> generator = torch.manual_seed(0)
  179. >>> image = pipe(
  180. ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
  181. ... ).images[0]
  182. ```
  183. """
  184. class StableDiffusionMultiControlNetPipeline(DiffusionPipeline):
  185. r"""
  186. Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
  187. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
  188. library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
  189. Args:
  190. vae ([`AutoencoderKL`]):
  191. Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
  192. text_encoder ([`CLIPTextModel`]):
  193. Frozen text-encoder. Stable Diffusion uses the text portion of
  194. [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
  195. the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
  196. tokenizer (`CLIPTokenizer`):
  197. Tokenizer of class
  198. [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
  199. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
  200. scheduler ([`SchedulerMixin`]):
  201. A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
  202. [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
  203. safety_checker ([`StableDiffusionSafetyChecker`]):
  204. Classification module that estimates whether generated images could be considered offensive or harmful.
  205. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
  206. feature_extractor ([`CLIPFeatureExtractor`]):
  207. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
  208. """
  209. _optional_components = ["safety_checker", "feature_extractor"]
  210. def __init__(
  211. self,
  212. vae: AutoencoderKL,
  213. text_encoder: CLIPTextModel,
  214. tokenizer: CLIPTokenizer,
  215. unet: UNet2DConditionModel,
  216. scheduler: KarrasDiffusionSchedulers,
  217. safety_checker: StableDiffusionSafetyChecker,
  218. feature_extractor: CLIPFeatureExtractor,
  219. requires_safety_checker: bool = True,
  220. ):
  221. super().__init__()
  222. if safety_checker is None and requires_safety_checker:
  223. logger.warning(
  224. f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
  225. " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
  226. " results in services or applications open to the public. Both the diffusers team and Hugging Face"
  227. " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
  228. " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
  229. " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
  230. )
  231. if safety_checker is not None and feature_extractor is None:
  232. raise ValueError(
  233. "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
  234. " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
  235. )
  236. self.register_modules(
  237. vae=vae,
  238. text_encoder=text_encoder,
  239. tokenizer=tokenizer,
  240. unet=unet,
  241. scheduler=scheduler,
  242. safety_checker=safety_checker,
  243. feature_extractor=feature_extractor,
  244. )
  245. self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
  246. self.register_to_config(requires_safety_checker=requires_safety_checker)
  247. # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
  248. def enable_vae_slicing(self):
  249. r"""
  250. Enable sliced VAE decoding.
  251. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
  252. steps. This is useful to save some memory and allow larger batch sizes.
  253. """
  254. self.vae.enable_slicing()
  255. # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
  256. def disable_vae_slicing(self):
  257. r"""
  258. Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
  259. computing decoding in one step.
  260. """
  261. self.vae.disable_slicing()
  262. def enable_sequential_cpu_offload(self, gpu_id=0):
  263. r"""
  264. Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
  265. text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
  266. `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
  267. Note that offloading happens on a submodule basis. Memory savings are higher than with
  268. `enable_model_cpu_offload`, but performance is lower.
  269. """
  270. if is_accelerate_available():
  271. from accelerate import cpu_offload
  272. else:
  273. raise ImportError("Please install accelerate via `pip install accelerate`")
  274. device = torch.device(f"cuda:{gpu_id}")
  275. for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
  276. cpu_offload(cpu_offloaded_model, device)
  277. if self.safety_checker is not None:
  278. cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
  279. def enable_model_cpu_offload(self, gpu_id=0):
  280. r"""
  281. Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
  282. to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
  283. method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
  284. `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
  285. """
  286. if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
  287. from accelerate import cpu_offload_with_hook
  288. else:
  289. raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
  290. device = torch.device(f"cuda:{gpu_id}")
  291. hook = None
  292. for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
  293. _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
  294. if self.safety_checker is not None:
  295. # the safety checker can offload the vae again
  296. _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
  297. # control net hook has be manually offloaded as it alternates with unet
  298. # cpu_offload_with_hook(self.controlnet, device)
  299. # We'll offload the last model manually.
  300. self.final_offload_hook = hook
  301. @property
  302. # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
  303. def _execution_device(self):
  304. r"""
  305. Returns the device on which the pipeline's models will be executed. After calling
  306. `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
  307. hooks.
  308. """
  309. if not hasattr(self.unet, "_hf_hook"):
  310. return self.device
  311. for module in self.unet.modules():
  312. if (
  313. hasattr(module, "_hf_hook")
  314. and hasattr(module._hf_hook, "execution_device")
  315. and module._hf_hook.execution_device is not None
  316. ):
  317. return torch.device(module._hf_hook.execution_device)
  318. return self.device
  319. # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
  320. def _encode_prompt(
  321. self,
  322. prompt,
  323. device,
  324. num_images_per_prompt,
  325. do_classifier_free_guidance,
  326. negative_prompt=None,
  327. prompt_embeds: Optional[torch.FloatTensor] = None,
  328. negative_prompt_embeds: Optional[torch.FloatTensor] = None,
  329. ):
  330. r"""
  331. Encodes the prompt into text encoder hidden states.
  332. Args:
  333. prompt (`str` or `List[str]`, *optional*):
  334. prompt to be encoded
  335. device: (`torch.device`):
  336. torch device
  337. num_images_per_prompt (`int`):
  338. number of images that should be generated per prompt
  339. do_classifier_free_guidance (`bool`):
  340. whether to use classifier free guidance or not
  341. negative_prompt (`str` or `List[str]`, *optional*):
  342. The prompt or prompts not to guide the image generation. If not defined, one has to pass
  343. `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
  344. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
  345. prompt_embeds (`torch.FloatTensor`, *optional*):
  346. Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
  347. provided, text embeddings will be generated from `prompt` input argument.
  348. negative_prompt_embeds (`torch.FloatTensor`, *optional*):
  349. Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
  350. weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
  351. argument.
  352. """
  353. if prompt is not None and isinstance(prompt, str):
  354. batch_size = 1
  355. elif prompt is not None and isinstance(prompt, list):
  356. batch_size = len(prompt)
  357. else:
  358. batch_size = prompt_embeds.shape[0]
  359. if prompt_embeds is None:
  360. text_inputs = self.tokenizer(
  361. prompt,
  362. padding="max_length",
  363. max_length=self.tokenizer.model_max_length,
  364. truncation=True,
  365. return_tensors="pt",
  366. )
  367. text_input_ids = text_inputs.input_ids
  368. untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
  369. if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
  370. text_input_ids, untruncated_ids
  371. ):
  372. removed_text = self.tokenizer.batch_decode(
  373. untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
  374. )
  375. logger.warning(
  376. "The following part of your input was truncated because CLIP can only handle sequences up to"
  377. f" {self.tokenizer.model_max_length} tokens: {removed_text}"
  378. )
  379. if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
  380. attention_mask = text_inputs.attention_mask.to(device)
  381. else:
  382. attention_mask = None
  383. prompt_embeds = self.text_encoder(
  384. text_input_ids.to(device),
  385. attention_mask=attention_mask,
  386. )
  387. prompt_embeds = prompt_embeds[0]
  388. prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
  389. bs_embed, seq_len, _ = prompt_embeds.shape
  390. # duplicate text embeddings for each generation per prompt, using mps friendly method
  391. prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
  392. prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
  393. # get unconditional embeddings for classifier free guidance
  394. if do_classifier_free_guidance and negative_prompt_embeds is None:
  395. uncond_tokens: List[str]
  396. if negative_prompt is None:
  397. uncond_tokens = [""] * batch_size
  398. elif type(prompt) is not type(negative_prompt):
  399. raise TypeError(
  400. f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
  401. f" {type(prompt)}."
  402. )
  403. elif isinstance(negative_prompt, str):
  404. uncond_tokens = [negative_prompt]
  405. elif batch_size != len(negative_prompt):
  406. raise ValueError(
  407. f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
  408. f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
  409. " the batch size of `prompt`."
  410. )
  411. else:
  412. uncond_tokens = negative_prompt
  413. max_length = prompt_embeds.shape[1]
  414. uncond_input = self.tokenizer(
  415. uncond_tokens,
  416. padding="max_length",
  417. max_length=max_length,
  418. truncation=True,
  419. return_tensors="pt",
  420. )
  421. if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
  422. attention_mask = uncond_input.attention_mask.to(device)
  423. else:
  424. attention_mask = None
  425. negative_prompt_embeds = self.text_encoder(
  426. uncond_input.input_ids.to(device),
  427. attention_mask=attention_mask,
  428. )
  429. negative_prompt_embeds = negative_prompt_embeds[0]
  430. if do_classifier_free_guidance:
  431. # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
  432. seq_len = negative_prompt_embeds.shape[1]
  433. negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
  434. negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
  435. negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
  436. # For classifier free guidance, we need to do two forward passes.
  437. # Here we concatenate the unconditional and text embeddings into a single batch
  438. # to avoid doing two forward passes
  439. prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
  440. return prompt_embeds
  441. # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
  442. def run_safety_checker(self, image, device, dtype):
  443. if self.safety_checker is not None:
  444. safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
  445. image, has_nsfw_concept = self.safety_checker(
  446. images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
  447. )
  448. else:
  449. has_nsfw_concept = None
  450. return image, has_nsfw_concept
  451. # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
  452. def decode_latents(self, latents):
  453. latents = 1 / self.vae.config.scaling_factor * latents
  454. image = self.vae.decode(latents).sample
  455. image = (image / 2 + 0.5).clamp(0, 1)
  456. # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
  457. image = image.cpu().permute(0, 2, 3, 1).float().numpy()
  458. return image
  459. # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
  460. def prepare_extra_step_kwargs(self, generator, eta):
  461. # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
  462. # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
  463. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
  464. # and should be between [0, 1]
  465. accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
  466. extra_step_kwargs = {}
  467. if accepts_eta:
  468. extra_step_kwargs["eta"] = eta
  469. # check if the scheduler accepts generator
  470. accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
  471. if accepts_generator:
  472. extra_step_kwargs["generator"] = generator
  473. return extra_step_kwargs
  474. def check_inputs(
  475. self,
  476. prompt,
  477. height,
  478. width,
  479. callback_steps,
  480. negative_prompt=None,
  481. prompt_embeds=None,
  482. negative_prompt_embeds=None,
  483. ):
  484. if height % 8 != 0 or width % 8 != 0:
  485. raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
  486. if (callback_steps is None) or (
  487. callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
  488. ):
  489. raise ValueError(
  490. f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
  491. f" {type(callback_steps)}."
  492. )
  493. if prompt is not None and prompt_embeds is not None:
  494. raise ValueError(
  495. f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
  496. " only forward one of the two."
  497. )
  498. elif prompt is None and prompt_embeds is None:
  499. raise ValueError(
  500. "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
  501. )
  502. elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
  503. raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
  504. if negative_prompt is not None and negative_prompt_embeds is not None:
  505. raise ValueError(
  506. f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
  507. f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
  508. )
  509. if prompt_embeds is not None and negative_prompt_embeds is not None:
  510. if prompt_embeds.shape != negative_prompt_embeds.shape:
  511. raise ValueError(
  512. "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
  513. f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
  514. f" {negative_prompt_embeds.shape}."
  515. )
  516. # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
  517. def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
  518. shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
  519. if isinstance(generator, list) and len(generator) != batch_size:
  520. raise ValueError(
  521. f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
  522. f" size of {batch_size}. Make sure the batch size matches the length of the generators."
  523. )
  524. if latents is None:
  525. latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
  526. else:
  527. latents = latents.to(device)
  528. # scale the initial noise by the standard deviation required by the scheduler
  529. latents = latents * self.scheduler.init_noise_sigma
  530. return latents
  531. @torch.no_grad()
  532. @replace_example_docstring(EXAMPLE_DOC_STRING)
  533. def __call__(
  534. self,
  535. processors: List[ControlNetProcessor],
  536. prompt: Union[str, List[str]] = None,
  537. height: Optional[int] = None,
  538. width: Optional[int] = None,
  539. num_inference_steps: int = 50,
  540. guidance_scale: float = 7.5,
  541. negative_prompt: Optional[Union[str, List[str]]] = None,
  542. num_images_per_prompt: Optional[int] = 1,
  543. eta: float = 0.0,
  544. generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
  545. latents: Optional[torch.FloatTensor] = None,
  546. prompt_embeds: Optional[torch.FloatTensor] = None,
  547. negative_prompt_embeds: Optional[torch.FloatTensor] = None,
  548. output_type: Optional[str] = "pil",
  549. return_dict: bool = True,
  550. callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
  551. callback_steps: int = 1,
  552. cross_attention_kwargs: Optional[Dict[str, Any]] = None,
  553. ):
  554. r"""
  555. Function invoked when calling the pipeline for generation.
  556. Args:
  557. prompt (`str` or `List[str]`, *optional*):
  558. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
  559. instead.
  560. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
  561. The height in pixels of the generated image.
  562. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
  563. The width in pixels of the generated image.
  564. num_inference_steps (`int`, *optional*, defaults to 50):
  565. The number of denoising steps. More denoising steps usually lead to a higher quality image at the
  566. expense of slower inference.
  567. guidance_scale (`float`, *optional*, defaults to 7.5):
  568. Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
  569. `guidance_scale` is defined as `w` of equation 2. of [Imagen
  570. Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
  571. 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
  572. usually at the expense of lower image quality.
  573. negative_prompt (`str` or `List[str]`, *optional*):
  574. The prompt or prompts not to guide the image generation. If not defined, one has to pass
  575. `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
  576. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
  577. num_images_per_prompt (`int`, *optional*, defaults to 1):
  578. The number of images to generate per prompt.
  579. eta (`float`, *optional*, defaults to 0.0):
  580. Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
  581. [`schedulers.DDIMScheduler`], will be ignored for others.
  582. generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
  583. One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
  584. to make generation deterministic.
  585. latents (`torch.FloatTensor`, *optional*):
  586. Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
  587. generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
  588. tensor will ge generated by sampling using the supplied random `generator`.
  589. prompt_embeds (`torch.FloatTensor`, *optional*):
  590. Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
  591. provided, text embeddings will be generated from `prompt` input argument.
  592. negative_prompt_embeds (`torch.FloatTensor`, *optional*):
  593. Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
  594. weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
  595. argument.
  596. output_type (`str`, *optional*, defaults to `"pil"`):
  597. The output format of the generate image. Choose between
  598. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
  599. return_dict (`bool`, *optional*, defaults to `True`):
  600. Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
  601. plain tuple.
  602. callback (`Callable`, *optional*):
  603. A function that will be called every `callback_steps` steps during inference. The function will be
  604. called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
  605. callback_steps (`int`, *optional*, defaults to 1):
  606. The frequency at which the `callback` function will be called. If not specified, the callback will be
  607. called at every step.
  608. cross_attention_kwargs (`dict`, *optional*):
  609. A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
  610. `self.processor` in
  611. [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
  612. Examples:
  613. Returns:
  614. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
  615. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
  616. When returning a tuple, the first element is a list with the generated images, and the second element is a
  617. list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
  618. (nsfw) content, according to the `safety_checker`.
  619. """
  620. # 0. Default height and width to unet
  621. height, width = processors[0].default_height_width(height, width)
  622. # 1. Check inputs. Raise error if not correct
  623. self.check_inputs(
  624. prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
  625. )
  626. for processor in processors:
  627. processor.check_inputs(prompt, prompt_embeds)
  628. # 2. Define call parameters
  629. if prompt is not None and isinstance(prompt, str):
  630. batch_size = 1
  631. elif prompt is not None and isinstance(prompt, list):
  632. batch_size = len(prompt)
  633. else:
  634. batch_size = prompt_embeds.shape[0]
  635. device = self._execution_device
  636. # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
  637. # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
  638. # corresponds to doing no classifier free guidance.
  639. do_classifier_free_guidance = guidance_scale > 1.0
  640. # 3. Encode input prompt
  641. prompt_embeds = self._encode_prompt(
  642. prompt,
  643. device,
  644. num_images_per_prompt,
  645. do_classifier_free_guidance,
  646. negative_prompt,
  647. prompt_embeds=prompt_embeds,
  648. negative_prompt_embeds=negative_prompt_embeds,
  649. )
  650. # 4. Prepare image
  651. for processor in processors:
  652. processor.prepare_image(
  653. width=width,
  654. height=height,
  655. batch_size=batch_size * num_images_per_prompt,
  656. num_images_per_prompt=num_images_per_prompt,
  657. device=device,
  658. do_classifier_free_guidance=do_classifier_free_guidance,
  659. )
  660. # 5. Prepare timesteps
  661. self.scheduler.set_timesteps(num_inference_steps, device=device)
  662. timesteps = self.scheduler.timesteps
  663. # 6. Prepare latent variables
  664. num_channels_latents = self.unet.in_channels
  665. latents = self.prepare_latents(
  666. batch_size * num_images_per_prompt,
  667. num_channels_latents,
  668. height,
  669. width,
  670. prompt_embeds.dtype,
  671. device,
  672. generator,
  673. latents,
  674. )
  675. # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
  676. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
  677. # 8. Denoising loop
  678. num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
  679. with self.progress_bar(total=num_inference_steps) as progress_bar:
  680. for i, t in enumerate(timesteps):
  681. # expand the latents if we are doing classifier free guidance
  682. latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
  683. latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
  684. # controlnet inference
  685. for i, processor in enumerate(processors):
  686. down_samples, mid_sample = processor(
  687. latent_model_input,
  688. t,
  689. encoder_hidden_states=prompt_embeds,
  690. return_dict=False,
  691. )
  692. if i == 0:
  693. down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
  694. else:
  695. down_block_res_samples = [
  696. d_prev + d_curr for d_prev, d_curr in zip(down_block_res_samples, down_samples)
  697. ]
  698. mid_block_res_sample = mid_block_res_sample + mid_sample
  699. # predict the noise residual
  700. noise_pred = self.unet(
  701. latent_model_input,
  702. t,
  703. encoder_hidden_states=prompt_embeds,
  704. cross_attention_kwargs=cross_attention_kwargs,
  705. down_block_additional_residuals=down_block_res_samples,
  706. mid_block_additional_residual=mid_block_res_sample,
  707. ).sample
  708. # perform guidance
  709. if do_classifier_free_guidance:
  710. noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
  711. noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
  712. # compute the previous noisy sample x_t -> x_t-1
  713. latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
  714. # call the callback, if provided
  715. if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
  716. progress_bar.update()
  717. if callback is not None and i % callback_steps == 0:
  718. callback(i, t, latents)
  719. # If we do sequential model offloading, let's offload unet and controlnet
  720. # manually for max memory savings
  721. if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
  722. self.unet.to("cpu")
  723. torch.cuda.empty_cache()
  724. if output_type == "latent":
  725. image = latents
  726. has_nsfw_concept = None
  727. elif output_type == "pil":
  728. # 8. Post-processing
  729. image = self.decode_latents(latents)
  730. # 9. Run safety checker
  731. image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
  732. # 10. Convert to PIL
  733. image = self.numpy_to_pil(image)
  734. else:
  735. # 8. Post-processing
  736. image = self.decode_latents(latents)
  737. # 9. Run safety checker
  738. image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
  739. # Offload last model to CPU
  740. if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
  741. self.final_offload_hook.offload()
  742. if not return_dict:
  743. return (image, has_nsfw_concept)
  744. return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
  745. # demo & simple test
  746. def main():
  747. from diffusers.utils import load_image
  748. pipe = StableDiffusionMultiControlNetPipeline.from_pretrained(
  749. "./model", safety_checker=None, torch_dtype=torch.float16
  750. ).to("cuda")
  751. pipe.enable_xformers_memory_efficient_attention()
  752. controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny",cache_dir='./controlmodel', torch_dtype=torch.float16).to(
  753. "cuda"
  754. )
  755. controlnet_pose = ControlNetModel.from_pretrained(
  756. "lllyasviel/sd-controlnet-openpose",cache_dir='./controlmodel', torch_dtype=torch.float16
  757. ).to("cuda")
  758. canny_left = load_image("https://huggingface.co/takuma104/controlnet_dev/resolve/main/vermeer_left.png")
  759. canny_right = load_image("https://huggingface.co/takuma104/controlnet_dev/resolve/main/vermeer_right.png")
  760. pose_right = load_image("https://huggingface.co/takuma104/controlnet_dev/resolve/main/pose_right.png")
  761. image = pipe(
  762. prompt="best quality, extremely detailed",
  763. negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
  764. processors=[
  765. ControlNetProcessor(controlnet_canny, canny_left),
  766. ControlNetProcessor(controlnet_canny, canny_right),
  767. ],
  768. generator=torch.Generator(device="cpu").manual_seed(0),
  769. num_inference_steps=30,
  770. width=512,
  771. height=512,
  772. ).images[0]
  773. image.save("./canny_left_right.png")
  774. image = pipe(
  775. prompt="best quality, extremely detailed",
  776. negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
  777. processors=[
  778. ControlNetProcessor(controlnet_canny, canny_left,0.5),
  779. ControlNetProcessor(controlnet_pose, pose_right,1.6),
  780. ],
  781. generator=torch.Generator(device="cpu").manual_seed(0),
  782. num_inference_steps=30,
  783. width=512,
  784. height=512,
  785. ).images[0]
  786. image.save("./canny_left_pose_right.png")
  787. if __name__ == "__main__":
  788. main()
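For reference, newer diffusers releases have since added native multi-ControlNet support: StableDiffusionControlNetPipeline accepts a list of ControlNets, a matching list of condition images, and per-ControlNet conditioning scales. A minimal sketch of that newer API (check your installed version before relying on it; the prompts, images, and scales mirror the main() demo above):

  import torch
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
  from diffusers.utils import load_image

  controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
  controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5",
      controlnet=[controlnet_canny, controlnet_pose],  # a list becomes a MultiControlNet under the hood
      safety_checker=None,
      torch_dtype=torch.float16,
  ).to("cuda")

  canny_left = load_image("https://huggingface.co/takuma104/controlnet_dev/resolve/main/vermeer_left.png")
  pose_right = load_image("https://huggingface.co/takuma104/controlnet_dev/resolve/main/pose_right.png")

  image = pipe(
      prompt="best quality, extremely detailed",
      negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
      image=[canny_left, pose_right],              # one condition image per ControlNet
      controlnet_conditioning_scale=[0.5, 1.6],    # per-ControlNet weights, as in main() above
      generator=torch.Generator(device="cpu").manual_seed(0),
      num_inference_steps=30,
  ).images[0]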
Implementing functionality that diffusers doesn't have yet

Suppose we want to implement a Stable Diffusion + ControlNet + inpainting feature ourselves. How would we go about it? Most of the task is chaining steps in the generation flow, so the code essentially lives in the stable_diffusion pipeline module.

Suppose our implementation looks like the full listing that follows the short usage sketch below. (The code only works against a specific diffusers version and will be upgraded later, so there is no rush to use it as-is; the point is to explain the diffusers code framework and how a modified module can be wired into diffusers for your own use.)
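Before the listing, a rough sketch of how such a custom pipeline class would be wired in once it is importable. Everything here is an assumption for illustration: the module name comes from the file-name comment in the listing, the relative imports in the listing (from ...models import ...) would first need to be made absolute or the file placed inside diffusers' pipelines/stable_diffusion package, and the exact call arguments are whatever the listing's own __call__ defines.

  import torch
  from diffusers import ControlNetModel

  # hypothetical: assumes the listing was saved next to this script with absolute imports
  from pipeline_stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline

  controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5",
      controlnet=controlnet,      # extra component registered by the custom __init__ below
      safety_checker=None,
      torch_dtype=torch.float16,
  ).to("cuda")
  # the init image, inpainting mask, and controlnet hint are then passed to pipe(...)
  # according to the __call__ signature defined in the listing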

  1. #代码文件名:pipeline_stable_diffusion_controlnet_inpaint.py
  2. # Copyright 2023 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import inspect
  16. from typing import Any, Callable, Dict, List, Optional, Union
  17. import numpy as np
  18. import PIL.Image
  19. import torch
  20. from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
  21. from ...models import AutoencoderKL, UNet2DConditionModel
  22. from ...schedulers import KarrasDiffusionSchedulers
  23. from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
  24. from ..pipeline_utils import DiffusionPipeline
  25. from . import StableDiffusionPipelineOutput
  26. from .safety_checker import StableDiffusionSafetyChecker
  27. logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  28. EXAMPLE_DOC_STRING = """
  29. Examples:
  30. ```py
  31. >>> from diffusers import StableDiffusionControlNetPipeline
  32. >>> from diffusers.utils import load_image
  33. >>> # Canny edged image for control
  34. >>> canny_edged_image = load_image(
  35. ... "https://huggingface.co/takuma104/controlnet_dev/resolve/main/vermeer_canny_edged.png"
  36. ... )
  37. >>> pipe = StableDiffusionControlNetPipeline.from_pretrained("takuma104/control_sd15_canny").to("cuda")
  38. >>> image = pipe(prompt="best quality, extremely detailed", controlnet_hint=canny_edged_image).images[0]
  39. ```
  40. """
  41. def prepare_mask_and_masked_image(image, mask):
  42. """
  43. Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
  44. converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
  45. ``image`` and ``1`` for the ``mask``.
  46. The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
  47. binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
  48. Args:
  49. image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
  50. It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
  51. ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
  52. mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
  53. It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
  54. ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
  55. Raises:
  56. ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
  57. should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
  58. TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
  59. (ot the other way around).
  60. Returns:
  61. tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
  62. dimensions: ``batch x channels x height x width``.
  63. """
  64. if isinstance(image, torch.Tensor):
  65. if not isinstance(mask, torch.Tensor):
  66. raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
  67. # Batch single image
  68. if image.ndim == 3:
  69. assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
  70. image = image.unsqueeze(0)
  71. # Batch and add channel dim for single mask
  72. if mask.ndim == 2:
  73. mask = mask.unsqueeze(0).unsqueeze(0)
  74. # Batch single mask or add channel dim
  75. if mask.ndim == 3:
  76. # Single batched mask, no channel dim or single mask not batched but channel dim
  77. if mask.shape[0] == 1:
  78. mask = mask.unsqueeze(0)
  79. # Batched masks no channel dim
  80. else:
  81. mask = mask.unsqueeze(1)
  82. assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
  83. assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
  84. assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
  85. # Check image is in [-1, 1]
  86. if image.min() < -1 or image.max() > 1:
  87. raise ValueError("Image should be in [-1, 1] range")
  88. # Check mask is in [0, 1]
  89. if mask.min() < 0 or mask.max() > 1:
  90. raise ValueError("Mask should be in [0, 1] range")
  91. # Binarize mask
  92. mask[mask < 0.5] = 0
  93. mask[mask >= 0.5] = 1
  94. # Image as float32
  95. image = image.to(dtype=torch.float32)
  96. elif isinstance(mask, torch.Tensor):
  97. raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
  98. else:
  99. # preprocess image
  100. if isinstance(image, (PIL.Image.Image, np.ndarray)):
  101. image = [image]
  102. if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
  103. image = [np.array(i.convert("RGB"))[None, :] for i in image]
  104. image = np.concatenate(image, axis=0)
  105. elif isinstance(image, list) and isinstance(image[0], np.ndarray):
  106. image = np.concatenate([i[None, :] for i in image], axis=0)
  107. image = image.transpose(0, 3, 1, 2)
  108. image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
  109. # preprocess mask
  110. if isinstance(mask, (PIL.Image.Image, np.ndarray)):
  111. mask = [mask]
  112. if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
  113. mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
  114. mask = mask.astype(np.float32) / 255.0
  115. elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
  116. mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
  117. mask[mask < 0.5] = 0
  118. mask[mask >= 0.5] = 1
  119. mask = torch.from_numpy(mask)
  120. masked_image = image * (mask < 0.5)
  121. return mask, masked_image
  122. class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
  123. r"""
  124. Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
  125. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
  126. library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
  127. Args:
  128. vae ([`AutoencoderKL`]):
  129. Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
  130. text_encoder ([`CLIPTextModel`]):
  131. Frozen text-encoder. Stable Diffusion uses the text portion of
  132. [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
  133. the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
  134. tokenizer (`CLIPTokenizer`):
  135. Tokenizer of class
  136. [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
  137. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
  138. controlnet ([`UNet2DConditionModel`]):
  139. [ControlNet](https://arxiv.org/abs/2302.05543) architecture to generate guidance.
  140. scheduler ([`SchedulerMixin`]):
  141. A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
  142. [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
  143. safety_checker ([`StableDiffusionSafetyChecker`]):
  144. Classification module that estimates whether generated images could be considered offensive or harmful.
  145. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
  146. feature_extractor ([`CLIPFeatureExtractor`]):
  147. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
  148. """
  149. def __init__(
  150. self,
  151. vae: AutoencoderKL,
  152. text_encoder: CLIPTextModel,
  153. tokenizer: CLIPTokenizer,
  154. unet: UNet2DConditionModel,
  155. controlnet: UNet2DConditionModel,
  156. scheduler: KarrasDiffusionSchedulers,
  157. safety_checker: StableDiffusionSafetyChecker,
  158. feature_extractor: CLIPFeatureExtractor,
  159. requires_safety_checker: bool = True,
  160. ):
  161. super().__init__()
  162. self.register_modules(
  163. vae=vae,
  164. text_encoder=text_encoder,
  165. tokenizer=tokenizer,
  166. unet=unet,
  167. controlnet=controlnet,
  168. scheduler=scheduler,
  169. safety_checker=safety_checker,
  170. feature_extractor=feature_extractor,
  171. )
  172. self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
  173. self.register_to_config(requires_safety_checker=requires_safety_checker)
  174. def enable_vae_slicing(self):
  175. r"""
  176. Enable sliced VAE decoding.
  177. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
  178. steps. This is useful to save some memory and allow larger batch sizes.
  179. """
  180. self.vae.enable_slicing()
  181. def disable_vae_slicing(self):
  182. r"""
  183. Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
  184. computing decoding in one step.
  185. """
  186. self.vae.disable_slicing()
  187. def enable_sequential_cpu_offload(self, gpu_id=0):
  188. r"""
  189. Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
  190. text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
191. `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
  192. """
  193. if is_accelerate_available():
  194. from accelerate import cpu_offload
  195. else:
  196. raise ImportError("Please install accelerate via `pip install accelerate`")
  197. device = torch.device(f"cuda:{gpu_id}")
  198. for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
  199. cpu_offload(cpu_offloaded_model, device)
  200. if self.safety_checker is not None:
  201. cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
  202. @property
  203. def _execution_device(self):
  204. r"""
  205. Returns the device on which the pipeline's models will be executed. After calling
  206. `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
  207. hooks.
  208. """
  209. if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
  210. return self.device
  211. for module in self.unet.modules():
  212. if (
  213. hasattr(module, "_hf_hook")
  214. and hasattr(module._hf_hook, "execution_device")
  215. and module._hf_hook.execution_device is not None
  216. ):
  217. return torch.device(module._hf_hook.execution_device)
  218. return self.device
  219. def _encode_prompt(
  220. self,
  221. prompt,
  222. device,
  223. num_images_per_prompt,
  224. do_classifier_free_guidance,
  225. negative_prompt=None,
  226. prompt_embeds: Optional[torch.FloatTensor] = None,
  227. negative_prompt_embeds: Optional[torch.FloatTensor] = None,
  228. ):
  229. r"""
  230. Encodes the prompt into text encoder hidden states.
  231. Args:
  232. prompt (`str` or `List[str]`, *optional*):
  233. prompt to be encoded
  234. device: (`torch.device`):
  235. torch device
  236. num_images_per_prompt (`int`):
  237. number of images that should be generated per prompt
  238. do_classifier_free_guidance (`bool`):
  239. whether to use classifier free guidance or not
  240. negative_prompt (`str` or `List[str]`, *optional*):
  241. The prompt or prompts not to guide the image generation. If not defined, one has to pass
242. `negative_prompt_embeds` instead.
  243. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
  244. prompt_embeds (`torch.FloatTensor`, *optional*):
  245. Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
  246. provided, text embeddings will be generated from `prompt` input argument.
  247. negative_prompt_embeds (`torch.FloatTensor`, *optional*):
  248. Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
  249. weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
  250. argument.
  251. """
  252. if prompt is not None and isinstance(prompt, str):
  253. batch_size = 1
  254. elif prompt is not None and isinstance(prompt, list):
  255. batch_size = len(prompt)
  256. else:
  257. batch_size = prompt_embeds.shape[0]
  258. if prompt_embeds is None:
  259. text_inputs = self.tokenizer(
  260. prompt,
  261. padding="max_length",
  262. max_length=self.tokenizer.model_max_length,
  263. truncation=True,
  264. return_tensors="pt",
  265. )
  266. text_input_ids = text_inputs.input_ids
  267. untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
  268. if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
  269. text_input_ids, untruncated_ids
  270. ):
  271. removed_text = self.tokenizer.batch_decode(
  272. untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
  273. )
  274. logger.warning(
  275. "The following part of your input was truncated because CLIP can only handle sequences up to"
  276. f" {self.tokenizer.model_max_length} tokens: {removed_text}"
  277. )
  278. if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
  279. attention_mask = text_inputs.attention_mask.to(device)
  280. else:
  281. attention_mask = None
  282. prompt_embeds = self.text_encoder(
  283. text_input_ids.to(device),
  284. attention_mask=attention_mask,
  285. )
  286. prompt_embeds = prompt_embeds[0]
  287. prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
  288. bs_embed, seq_len, _ = prompt_embeds.shape
  289. # duplicate text embeddings for each generation per prompt, using mps friendly method
  290. prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
  291. prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
  292. # get unconditional embeddings for classifier free guidance
  293. if do_classifier_free_guidance and negative_prompt_embeds is None:
  294. uncond_tokens: List[str]
  295. if negative_prompt is None:
  296. uncond_tokens = [""] * batch_size
  297. elif type(prompt) is not type(negative_prompt):
  298. raise TypeError(
  299. f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
  300. f" {type(prompt)}."
  301. )
  302. elif isinstance(negative_prompt, str):
  303. uncond_tokens = [negative_prompt]
  304. elif batch_size != len(negative_prompt):
  305. raise ValueError(
  306. f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
  307. f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
  308. " the batch size of `prompt`."
  309. )
  310. else:
  311. uncond_tokens = negative_prompt
  312. max_length = prompt_embeds.shape[1]
  313. uncond_input = self.tokenizer(
  314. uncond_tokens,
  315. padding="max_length",
  316. max_length=max_length,
  317. truncation=True,
  318. return_tensors="pt",
  319. )
  320. if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
  321. attention_mask = uncond_input.attention_mask.to(device)
  322. else:
  323. attention_mask = None
  324. negative_prompt_embeds = self.text_encoder(
  325. uncond_input.input_ids.to(device),
  326. attention_mask=attention_mask,
  327. )
  328. negative_prompt_embeds = negative_prompt_embeds[0]
  329. if do_classifier_free_guidance:
  330. # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
  331. seq_len = negative_prompt_embeds.shape[1]
  332. negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
  333. negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
  334. negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
  335. # For classifier free guidance, we need to do two forward passes.
  336. # Here we concatenate the unconditional and text embeddings into a single batch
  337. # to avoid doing two forward passes
  338. prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
  339. return prompt_embeds
  340. def run_safety_checker(self, image, device, dtype):
  341. if self.safety_checker is not None:
  342. safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
  343. image, has_nsfw_concept = self.safety_checker(
  344. images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
  345. )
  346. else:
  347. has_nsfw_concept = None
  348. return image, has_nsfw_concept
  349. def decode_latents(self, latents):
  350. latents = 1 / self.vae.config.scaling_factor * latents
  351. image = self.vae.decode(latents).sample
  352. image = (image / 2 + 0.5).clamp(0, 1)
  353. # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
  354. image = image.cpu().permute(0, 2, 3, 1).float().numpy()
  355. return image
  356. def prepare_extra_step_kwargs(self, generator, eta):
  357. # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
  358. # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
  359. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
  360. # and should be between [0, 1]
  361. accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
  362. extra_step_kwargs = {}
  363. if accepts_eta:
  364. extra_step_kwargs["eta"] = eta
  365. # check if the scheduler accepts generator
  366. accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
  367. if accepts_generator:
  368. extra_step_kwargs["generator"] = generator
  369. return extra_step_kwargs
  377. def check_inputs(
  378. self,
  379. prompt,
  380. height,
  381. width,
  382. callback_steps,
  383. negative_prompt=None,
  384. prompt_embeds=None,
  385. negative_prompt_embeds=None,
  386. ):
  387. if height % 8 != 0 or width % 8 != 0:
  388. raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
  389. if (callback_steps is None) or (
  390. callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
  391. ):
  392. raise ValueError(
  393. f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
  394. f" {type(callback_steps)}."
  395. )
  396. if prompt is not None and prompt_embeds is not None:
  397. raise ValueError(
  398. f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
  399. " only forward one of the two."
  400. )
  401. elif prompt is None and prompt_embeds is None:
  402. raise ValueError(
  403. "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
  404. )
  405. elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
  406. raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
  407. if negative_prompt is not None and negative_prompt_embeds is not None:
  408. raise ValueError(
  409. f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
  410. f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
  411. )
  412. if prompt_embeds is not None and negative_prompt_embeds is not None:
  413. if prompt_embeds.shape != negative_prompt_embeds.shape:
  414. raise ValueError(
  415. "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
  416. f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
  417. f" {negative_prompt_embeds.shape}."
  418. )
  419. def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
  420. shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
  421. if isinstance(generator, list) and len(generator) != batch_size:
  422. raise ValueError(
  423. f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
  424. f" size of {batch_size}. Make sure the batch size matches the length of the generators."
  425. )
  426. if latents is None:
  427. latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
  428. else:
  429. latents = latents.to(device)
  430. # scale the initial noise by the standard deviation required by the scheduler
  431. latents = latents * self.scheduler.init_noise_sigma
  432. return latents
  433. def controlnet_hint_conversion(self, controlnet_hint, height, width, num_images_per_prompt):
  434. channels = 3
  435. if isinstance(controlnet_hint, torch.Tensor):
436. # torch.Tensor: acceptable shapes are any of chw, bchw(b==1) or bchw(b==num_images_per_prompt)
  437. shape_chw = (channels, height, width)
  438. shape_bchw = (1, channels, height, width)
  439. shape_nchw = (num_images_per_prompt, channels, height, width)
  440. if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]:
  441. controlnet_hint = controlnet_hint.to(dtype=self.controlnet.dtype, device=self.controlnet.device)
  442. if controlnet_hint.shape != shape_nchw:
  443. controlnet_hint = controlnet_hint.repeat(num_images_per_prompt, 1, 1, 1)
  444. return controlnet_hint
  445. else:
  446. raise ValueError(
  447. f"Acceptble shape of `controlnet_hint` are any of ({channels}, {height}, {width}),"
  448. + f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
  449. + f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
  450. )
  451. elif isinstance(controlnet_hint, np.ndarray):
452. # np.ndarray: acceptable shapes are any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_prompt)
  453. # hwc is opencv compatible image format. Color channel must be BGR Format.
  454. if controlnet_hint.shape == (height, width):
  455. controlnet_hint = np.repeat(controlnet_hint[:, :, np.newaxis], channels, axis=2) # hw -> hwc(c==3)
  456. shape_hwc = (height, width, channels)
  457. shape_bhwc = (1, height, width, channels)
  458. shape_nhwc = (num_images_per_prompt, height, width, channels)
  459. if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]:
  460. controlnet_hint = torch.from_numpy(controlnet_hint.copy())
  461. controlnet_hint = controlnet_hint.to(dtype=self.controlnet.dtype, device=self.controlnet.device)
  462. controlnet_hint /= 255.0
  463. if controlnet_hint.shape != shape_nhwc:
  464. controlnet_hint = controlnet_hint.repeat(num_images_per_prompt, 1, 1, 1)
  465. controlnet_hint = controlnet_hint.permute(0, 3, 1, 2) # b h w c -> b c h w
  466. return controlnet_hint
  467. else:
  468. raise ValueError(
  469. f"Acceptble shape of `controlnet_hint` are any of ({width}, {channels}), "
  470. + f"({height}, {width}, {channels}), "
  471. + f"(1, {height}, {width}, {channels}) or "
  472. + f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
  473. )
  474. elif isinstance(controlnet_hint, PIL.Image.Image):
  475. if controlnet_hint.size == (width, height):
  476. controlnet_hint = controlnet_hint.convert("RGB") # make sure 3 channel RGB format
  477. controlnet_hint = np.array(controlnet_hint) # to numpy
  478. controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
  479. return self.controlnet_hint_conversion(controlnet_hint, height, width, num_images_per_prompt)
  480. else:
  481. raise ValueError(
  482. f"Acceptable image size of `controlnet_hint` is ({width}, {height}) but is {controlnet_hint.size}"
  483. )
  484. else:
  485. raise ValueError(
  486. f"Acceptable type of `controlnet_hint` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
  487. )
  488. def prepare_mask_latents(
  489. self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
  490. ):
  491. # resize the mask to latents shape as we concatenate the mask to the latents
  492. # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
  493. # and half precision
  494. mask = torch.nn.functional.interpolate(
  495. mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
  496. )
  497. mask = mask.to(device=device, dtype=dtype)
  498. masked_image = masked_image.to(device=device, dtype=dtype)
  499. # encode the mask image into latents space so we can concatenate it to the latents
  500. if isinstance(generator, list):
  501. masked_image_latents = [
  502. self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
  503. for i in range(batch_size)
  504. ]
  505. masked_image_latents = torch.cat(masked_image_latents, dim=0)
  506. else:
  507. masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
  508. masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
  509. # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
  510. if mask.shape[0] < batch_size:
  511. if not batch_size % mask.shape[0] == 0:
  512. raise ValueError(
  513. "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
  514. f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
  515. " of masks that you pass is divisible by the total requested batch size."
  516. )
  517. mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
  518. if masked_image_latents.shape[0] < batch_size:
  519. if not batch_size % masked_image_latents.shape[0] == 0:
  520. raise ValueError(
  521. "The passed images and the required batch size don't match. Images are supposed to be duplicated"
  522. f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
  523. " Make sure the number of images that you pass is divisible by the total requested batch size."
  524. )
  525. masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
  526. mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
  527. masked_image_latents = (
  528. torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
  529. )
  530. # aligning device to prevent device errors when concating it with the latent model input
  531. masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
  532. return mask, masked_image_latents
  533. @torch.no_grad()
  534. @replace_example_docstring(EXAMPLE_DOC_STRING)
  535. def __call__(
  536. self,
  537. prompt: Union[str, List[str]] = None,
  538. height: Optional[int] = None,
  539. width: Optional[int] = None,
  540. num_inference_steps: int = 50,
  541. guidance_scale: float = 7.5,
  542. negative_prompt: Optional[Union[str, List[str]]] = None,
  543. num_images_per_prompt: Optional[int] = 1,
  544. eta: float = 0.0,
  545. generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
  546. latents: Optional[torch.FloatTensor] = None,
  547. prompt_embeds: Optional[torch.FloatTensor] = None,
  548. negative_prompt_embeds: Optional[torch.FloatTensor] = None,
  549. output_type: Optional[str] = "pil",
  550. return_dict: bool = True,
  551. callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
  552. callback_steps: Optional[int] = 1,
  553. cross_attention_kwargs: Optional[Dict[str, Any]] = None,
  554. controlnet_hint: Optional[Union[torch.FloatTensor, np.ndarray, PIL.Image.Image]] = None,
  555. image: Union[torch.FloatTensor, PIL.Image.Image] = None,
  556. mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
  557. ):
  558. r"""
  559. Function invoked when calling the pipeline for generation.
  560. Args:
  561. prompt (`str` or `List[str]`, *optional*):
  562. The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
  563. instead.
  564. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
  565. The height in pixels of the generated image.
  566. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
  567. The width in pixels of the generated image.
  568. num_inference_steps (`int`, *optional*, defaults to 50):
  569. The number of denoising steps. More denoising steps usually lead to a higher quality image at the
  570. expense of slower inference.
  571. guidance_scale (`float`, *optional*, defaults to 7.5):
  572. Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
  573. `guidance_scale` is defined as `w` of equation 2. of [Imagen
  574. Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
  575. 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
  576. usually at the expense of lower image quality.
  577. negative_prompt (`str` or `List[str]`, *optional*):
  578. The prompt or prompts not to guide the image generation. If not defined, one has to pass
579. `negative_prompt_embeds` instead.
  580. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
  581. num_images_per_prompt (`int`, *optional*, defaults to 1):
  582. The number of images to generate per prompt.
  583. eta (`float`, *optional*, defaults to 0.0):
  584. Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
  585. [`schedulers.DDIMScheduler`], will be ignored for others.
  586. generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
  587. One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
  588. to make generation deterministic.
  589. latents (`torch.FloatTensor`, *optional*):
  590. Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
  591. generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
592. tensor will be generated by sampling using the supplied random `generator`.
  593. prompt_embeds (`torch.FloatTensor`, *optional*):
  594. Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
  595. provided, text embeddings will be generated from `prompt` input argument.
  596. negative_prompt_embeds (`torch.FloatTensor`, *optional*):
  597. Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
  598. weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
  599. argument.
  600. output_type (`str`, *optional*, defaults to `"pil"`):
  601. The output format of the generate image. Choose between
  602. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
  603. return_dict (`bool`, *optional*, defaults to `True`):
  604. Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
  605. plain tuple.
  606. callback (`Callable`, *optional*):
  607. A function that will be called every `callback_steps` steps during inference. The function will be
  608. called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
  609. callback_steps (`int`, *optional*, defaults to 1):
  610. The frequency at which the `callback` function will be called. If not specified, the callback will be
  611. called at every step.
  612. cross_attention_kwargs (`dict`, *optional*):
  613. A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
  614. `self.processor` in
  615. [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
  616. controlnet_hint (`torch.FloatTensor`, `np.ndarray` or `PIL.Image.Image`, *optional*):
  617. ControlNet input embedding. ControlNet generates guidances using this input embedding. If the type is
  618. specified as `torch.FloatTensor`, it is passed to ControlNet as is. If the type is `np.ndarray`, it is
619. assumed to be an OpenCV compatible image format. `PIL.Image.Image` can also be accepted as an image. The
  620. size of all these types must correspond to the output image size.
  621. Examples:
  622. Returns:
  623. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
624. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
  625. When returning a tuple, the first element is a list with the generated images, and the second element is a
  626. list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
  627. (nsfw) content, according to the `safety_checker`.
  628. """
  629. # 0. Default height and width to unet
  630. height = height or self.unet.config.sample_size * self.vae_scale_factor
  631. width = width or self.unet.config.sample_size * self.vae_scale_factor
  632. # 1. Control Embedding check & conversion
  633. if controlnet_hint is not None:
  634. controlnet_hint = self.controlnet_hint_conversion(controlnet_hint, height, width, num_images_per_prompt)
  635. # 2. Check inputs. Raise error if not correct
  636. self.check_inputs(
  637. prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
  638. )
  639. # 3. Define call parameters
  640. if prompt is not None and isinstance(prompt, str):
  641. batch_size = 1
  642. elif prompt is not None and isinstance(prompt, list):
  643. batch_size = len(prompt)
  644. else:
  645. batch_size = prompt_embeds.shape[0]
  646. device = self._execution_device
  647. # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
  648. # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
  649. # corresponds to doing no classifier free guidance.
  650. do_classifier_free_guidance = guidance_scale > 1.0
  651. # 4. Encode input prompt
  652. prompt_embeds = self._encode_prompt(
  653. prompt,
  654. device,
  655. num_images_per_prompt,
  656. do_classifier_free_guidance,
  657. negative_prompt,
  658. prompt_embeds=prompt_embeds,
  659. negative_prompt_embeds=negative_prompt_embeds,
  660. )
  661. mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
  662. # 5. Prepare timesteps
  663. self.scheduler.set_timesteps(num_inference_steps, device=device)
  664. timesteps = self.scheduler.timesteps
  665. # 6. Prepare latent variables
666. num_channels_latents = self.vae.config.latent_channels  # only the image latents; the mask and masked-image latents are concatenated below
  667. latents = self.prepare_latents(
  668. batch_size * num_images_per_prompt,
  669. num_channels_latents,
  670. height,
  671. width,
  672. prompt_embeds.dtype,
  673. device,
  674. generator,
  675. latents,
  676. )
  677. mask, masked_image_latents = self.prepare_mask_latents(
  678. mask,
  679. masked_image,
  680. batch_size * num_images_per_prompt,
  681. height,
  682. width,
  683. prompt_embeds.dtype,
  684. device,
  685. generator,
  686. do_classifier_free_guidance,
  687. )
  688. num_channels_mask = mask.shape[1]
  689. num_channels_masked_image = masked_image_latents.shape[1]
  690. # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
  691. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
  692. # 8. Denoising loop
  693. num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
  694. with self.progress_bar(total=num_inference_steps) as progress_bar:
  695. for i, t in enumerate(timesteps):
  696. # expand the latents if we are doing classifier free guidance
  697. latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
  698. latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
  699. if controlnet_hint is not None:
  700. # ControlNet predict the noise residual
  701. control = self.controlnet(
  702. latent_model_input, t, encoder_hidden_states=prompt_embeds, controlnet_hint=controlnet_hint
  703. )
  704. control = [item for item in control]
  705. latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
  706. noise_pred = self.unet(
  707. latent_model_input,
  708. t,
  709. encoder_hidden_states=prompt_embeds,
  710. cross_attention_kwargs=cross_attention_kwargs,
  711. control=control,
  712. ).sample
  713. else:
  714. # predict the noise residual
  715. latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
  716. noise_pred = self.unet(
  717. latent_model_input,
  718. t,
  719. encoder_hidden_states=prompt_embeds,
  720. cross_attention_kwargs=cross_attention_kwargs,
  721. ).sample
  722. # perform guidance
  723. if do_classifier_free_guidance:
  724. noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
  725. noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
  726. # compute the previous noisy sample x_t -> x_t-1
  727. latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
  728. # call the callback, if provided
  729. if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
  730. progress_bar.update()
  731. if callback is not None and i % callback_steps == 0:
  732. callback(i, t, latents)
  733. if output_type == "latent":
  734. image = latents
  735. has_nsfw_concept = None
  736. elif output_type == "pil":
  737. # 8. Post-processing
  738. image = self.decode_latents(latents)
  739. # 9. Run safety checker
  740. image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
  741. # 10. Convert to PIL
  742. image = self.numpy_to_pil(image)
  743. else:
  744. # 8. Post-processing
  745. image = self.decode_latents(latents)
  746. # 9. Run safety checker
  747. image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
  748. if not return_dict:
  749. return (image, has_nsfw_concept)
  750. return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
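
With the class complete, it can be smoke-tested by importing it directly from the file, before touching the diffusers package itself. The sketch below is illustrative only: the module name pipeline_stable_diffusion_controlnet_inpaint, the checkpoint name and the image paths are placeholders, `controlnet` is left as a placeholder for the controlnet_hint-style ControlNet built in the earlier steps of this series, and the patched UNet that accepts the extra `control` argument is assumed to be installed.

import torch
from diffusers.utils import load_image
# assumed local file name; adjust to wherever the class above was saved
from pipeline_stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline

# Placeholder: the controlnet_hint-style ControlNet from the earlier steps of this series.
# The stock diffusers ControlNetModel uses a different calling convention and will not work here.
controlnet = ...

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",  # 9-channel inpainting UNet
    controlnet=controlnet,
    safety_checker=None,
).to("cuda")

init_image = load_image("input.png").resize((512, 512))
mask_image = load_image("mask.png").resize((512, 512))        # white = region to repaint
hint_image = load_image("canny_hint.png").resize((512, 512))  # e.g. a Canny edge map

result = pipe(
    prompt="best quality, extremely detailed",
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    image=init_image,
    mask_image=mask_image,
    controlnet_hint=hint_image,
    num_inference_steps=30,
    generator=torch.Generator(device="cpu").manual_seed(3),
).images[0]
result.save("inpainted.png")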

Place the file at:

Every __init__.py under the following project directories needs the newly added class registered in it.

For what exactly to add, follow the existing style inside those files; it generally looks like this:
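
The snippet below is a sketch of those registration lines; the file name pipeline_stable_diffusion_controlnet_inpaint.py and its location under src/diffusers/pipelines/stable_diffusion/ are assumptions, since the exact path is not shown above:

# src/diffusers/pipelines/stable_diffusion/__init__.py
from .pipeline_stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline

# src/diffusers/pipelines/__init__.py
from .stable_diffusion import StableDiffusionControlNetInpaintPipeline

# src/diffusers/__init__.py
from .pipelines import StableDiffusionControlNetInpaintPipeline

After reinstalling the package (for example, pip install -e . from the diffusers source root), the new class can be imported with from diffusers import StableDiffusionControlNetInpaintPipeline, just like the built-in pipelines.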
