
First Look at the Essentials: Code Analysis (Text-to-Image) and Paper Introduction for the Open-Source AI Painting Tool Stable Diffusion (Part 1)


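I previously compiled and briefly introduced deployment resources and guides for AI drawing tools, but felt the Stable Diffusion part was not satisfying enough, so I started digging into the paper and the code. The earlier post is linked below: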



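Stable Diffusion has been wildly popular lately. On Bilibili, one uploader's multi-part video series analyzing the paper stopped abruptly at the introduction; the follow-up videos were all about sharing and demonstrating all-in-one packages, and the view count per video shot straight up, haha.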



The first half of this post takes a close look at the paper behind Stable Diffusion, High-Resolution Image Synthesis with Latent Diffusion Models. Paper link: https://arxiv.org/pdf/2112.10752.pdf

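The second half digs into the Stable Diffusion project code, which is written mainly in Python, and analyzes the text-to-image code (the core of the model will be analyzed in the next post).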

I. Stable Diffusion Paper Analysis

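In its overview, High-Resolution Image Synthesis with Latent Diffusion Models splits Stable Diffusion into 2 stages:

The first stage is perceptual compression: an autoencoder is trained once and then reused as a pretrained autoencoder for downsampling (encoding) and upsampling (decoding). The autoencoder learns a latent space that is much smaller than pixel space, and the diffusion model is trained in that latent space;

The second stage is the diffusion model, the semantic compression stage: between downsampling and upsampling, it introduces cross-attention layers for conditioning inputs such as text, bounding boxes, and images.

As a result, only the part between downsampling and upsampling needs to be trained, which greatly reduces both the compute requirements and the training time. This model is called Latent Diffusion Models.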
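To get a feel for how much smaller the latent space is, here is a rough back-of-the-envelope sketch. It assumes the default values of the `--H`, `--W`, `--C` and `--f` arguments of `txt2img.py` that appear in the code section of this post, and an RGB image with 3 channels; other configurations will give different ratios.

```python
# Defaults assumed from the txt2img.py arguments (--H, --W, --C, --f).
H, W = 512, 512   # image size in pixel space
C, f = 4, 8       # latent channels and downsampling factor

pixel_values = H * W * 3                  # values in an RGB image
latent_values = C * (H // f) * (W // f)   # values in the latent tensor

print(pixel_values)                   # 786432
print(latent_values)                  # 16384
print(pixel_values / latent_values)   # 48.0
```

Under these assumptions the diffusion model works on a tensor with 48 times fewer values than the pixel image.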








Latent Diffusion Models are conditioned via concatenation or via a more general cross-attention mechanism.
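The cross-attention idea can be sketched with plain NumPy: queries come from the latent features, while keys and values come from the conditioning input (for example text embeddings). This is only a minimal single-head sketch of scaled dot-product attention; all shapes, the random projection matrices, and the function name `cross_attention` are hypothetical and are not taken from the Stable Diffusion code.

```python
import numpy as np

def cross_attention(latent_tokens, cond_tokens, w_q, w_k, w_v):
    """Single-head cross-attention: latents attend to the conditioning tokens."""
    q = latent_tokens @ w_q            # queries from the latent features
    k = cond_tokens @ w_k              # keys from the conditioning input
    v = cond_tokens @ w_v              # values from the conditioning input
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over cond tokens
    return weights @ v, weights

rng = np.random.default_rng(0)
latent_tokens = rng.normal(size=(16, 32))   # 16 latent positions (hypothetical)
cond_tokens = rng.normal(size=(5, 24))      # 5 conditioning tokens (hypothetical)
w_q = rng.normal(size=(32, 8))
w_k = rng.normal(size=(24, 8))
w_v = rng.normal(size=(24, 8))

out, weights = cross_attention(latent_tokens, cond_tokens, w_q, w_k, w_v)
print(out.shape)      # (16, 8)
print(weights.shape)  # (16, 5)
```

Each latent position ends up with a weighted mix of the conditioning tokens, which is how the condition can steer every location of the image.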




II. Stable Diffusion Code Analysis



    import argparse, os, sys, glob
    import cv2
    import torch
    import numpy as np
    from omegaconf import OmegaConf
    from PIL import Image
    from tqdm import tqdm, trange
    from imwatermark import WatermarkEncoder
    from itertools import islice
    from einops import rearrange
    from torchvision.utils import make_grid
    import time
    from pytorch_lightning import seed_everything
    from torch import autocast
    from contextlib import contextmanager, nullcontext
    from ldm.util import instantiate_from_config
    from ldm.models.diffusion.ddim import DDIMSampler
    from ldm.models.diffusion.plms import PLMSSampler
    from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
    from transformers import AutoFeatureExtractor




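imwatermark (an invisible-watermark library; images produced by the original Stable Diffusion source code all carry the invisible watermark "StableDiffusionV1", though from what I saw the webui version does not add it);

ldm (the latent diffusion models package that lives in the project's own code; the core of Stable Diffusion image generation);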




    # load safety model
    safety_model_id = "CompVis/stable-diffusion-safety-checker"
    safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

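Abbreviated NSFW. Simply put, this setting decides whether only safe content is generated and inappropriate content is avoided (;OдO). Stable Diffusion defaults to SAFE FOR WORK, which is implemented by the following function: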

    def check_safety(x_image):
        safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
        x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
        assert x_checked_image.shape[0] == len(has_nsfw_concept)
        for i in range(len(has_nsfw_concept)):
            if has_nsfw_concept[i]:
                x_checked_image[i] = load_replacement(x_checked_image[i])
        return x_checked_image, has_nsfw_concept
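The per-image replacement loop in `check_safety` can be illustrated in isolation. The sketch below only mirrors the control flow: `replace_flagged` and the constant gray placeholder are hypothetical stand-ins, since `load_replacement` and the real checker are not shown in this post.

```python
import numpy as np

def replace_flagged(images, flags, placeholder_value=0.5):
    """Return a copy of `images` in which every flagged image is overwritten.

    `placeholder_value` is a hypothetical stand-in for whatever
    load_replacement() returns in the real script.
    """
    assert images.shape[0] == len(flags)   # one flag per image in the batch
    checked = images.copy()
    for i, flagged in enumerate(flags):
        if flagged:
            checked[i] = np.full_like(checked[i], placeholder_value)
    return checked, flags

batch = np.random.rand(3, 8, 8, 3)   # 3 small fake images, values in [0, 1)
checked, flags = replace_flagged(batch, [False, True, False])
print(np.array_equal(checked[0], batch[0]))  # True: unflagged image untouched
print(float(checked[1].min()), float(checked[1].max()))  # 0.5 0.5
```

The function returns both the (possibly modified) batch and the flags, matching the two-value return of `check_safety`.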


    def check_safety(x_image):
        # safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
        # x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
        # assert x_checked_image.shape[0] == len(has_nsfw_concept)
        # for i in range(len(has_nsfw_concept)):
        #     if has_nsfw_concept[i]:
        #         x_checked_image[i] = load_replacement(x_checked_image[i])
        return x_image, False

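Simply put, this returns the input image to the output directly, without running the safety check. The second return value is not used later in the main function, so returning an arbitrary False is fine.

I checked, and the NSFW filter in the webui version is off by default, which is why most of the NovelAI "spells" found online put nsfw in the negative tags. Surely some people have bold ideas... No, you don't ( ̄▽ ̄)/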


    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--laion400m",
        action='store_true',
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across samples ",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=2,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/stable-diffusion-v1/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    opt = parser.parse_args()
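To see how these definitions turn into attributes on `opt`, here is a small sketch that rebuilds just four of the arguments and parses an explicit argument list instead of `sys.argv`. The reduced parser is only for illustration; the argument names and defaults are copied from the listing above.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, nargs="?",
                    default="a painting of a virus monster playing guitar")
parser.add_argument("--plms", action="store_true")       # flag: False unless given
parser.add_argument("--n_samples", type=int, default=3)
parser.add_argument("--from-file", type=str)             # stored as opt.from_file

# Parse a list directly so the example does not depend on the command line.
opt = parser.parse_args(
    ["--prompt", "a photograph of an astronaut riding a horse", "--plms"]
)

print(opt.prompt)     # a photograph of an astronaut riding a horse
print(opt.plms)       # True
print(opt.n_samples)  # 3
print(opt.from_file)  # None
```

Note that argparse converts the dash in `--from-file` to an underscore, which is why the script later reads `opt.from_file`.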


    usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA]
                      [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT]
                      [--seed SEED] [--precision {full,autocast}]

    optional arguments:
      -h, --help            show this help message and exit
      --prompt [PROMPT]     the prompt to render
      --outdir [OUTDIR]     dir to write results to
      --skip_grid           do not save a grid, only individual samples. Helpful when evaluating lots of samples
      --skip_save           do not save individual samples. For speed measurements.
      --ddim_steps DDIM_STEPS
                            number of ddim sampling steps
      --plms                use plms sampling
      --laion400m           uses the LAION400M model
      --fixed_code          if enabled, uses the same starting code across samples
      --ddim_eta DDIM_ETA   ddim eta (eta=0.0 corresponds to deterministic sampling
      --n_iter N_ITER       sample this often
      --H H                 image height, in pixel space
      --W W                 image width, in pixel space
      --C C                 latent channels
      --f F                 downsampling factor
      --n_samples N_SAMPLES
                            how many samples to produce for each given prompt. A.k.a. batch size
      --n_rows N_ROWS       rows in the grid (default: n_samples)
      --scale SCALE         unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
      --from-file FROM_FILE
                            if specified, load prompts from this file
      --config CONFIG       path to config which constructs model
      --ckpt CKPT           path to checkpoint of model
      --seed SEED           the seed (for reproducible sampling)
      --precision {full,autocast}
                            evaluate at this precision
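The formula in the `--scale` help text can be tried out on toy numbers. The sketch below applies it literally to small NumPy arrays; `guided_eps` and the example values are hypothetical, and in the real sampler the two `eps` terms would be noise predictions from the model rather than hand-written vectors.

```python
import numpy as np

def guided_eps(eps_uncond, eps_cond, scale):
    """eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))"""
    return eps_uncond + scale * (eps_cond - eps_uncond)

eps_uncond = np.array([0.0, 1.0, -1.0])   # prediction for the empty prompt
eps_cond = np.array([1.0, 1.0, 0.0])      # prediction for the actual prompt

same_as_cond = guided_eps(eps_uncond, eps_cond, 1.0)   # scale 1: plain conditional
pushed = guided_eps(eps_uncond, eps_cond, 7.5)         # default --scale value

print(same_as_cond)
print(pushed)
```

With `scale` equal to 1 the result is just the conditional prediction; larger values push the result further along the direction from the unconditional to the conditional prediction.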


python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms 


    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"
    seed_everything(opt.seed)
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)
    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir




    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
    wm = "StableDiffusionV1"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))



    img = put_watermark(img, wm_encoder)

    def put_watermark(img, wm_encoder=None):
        if wm_encoder is not None:
            img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
            img = wm_encoder.encode(img, 'dwtDct')
            img = Image.fromarray(img[:, :, ::-1])
        return img

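Note that, compared with imwatermark's own example of adding an invisible watermark to an ordinary png image, the code here first converts RGB to BGR, then embeds the watermark, and finally converts back to RGB channel order with Image.fromarray(img[:, :, ::-1]).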
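The channel round trip can be checked with NumPy alone. This sketch uses a tiny made-up image and reverses the last axis in both directions; for a 3-channel image, reversing the last axis gives the same channel order as the `cv2.COLOR_RGB2BGR` conversion, so OpenCV is not needed here.

```python
import numpy as np

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255            # pure red when read as RGB

bgr = rgb[:, :, ::-1]        # reverse the channel axis: RGB -> BGR
back = bgr[:, :, ::-1]       # reverse again: BGR -> RGB

print(bgr[0, 0])                   # [  0   0 255]
print(np.array_equal(back, rgb))   # True
```

The red value moves from the first channel to the last one and back, which is exactly what `put_watermark` relies on before handing the array to `Image.fromarray`.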


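    #!/usr/bin/env python3
    import cv2
    from imwatermark import WatermarkDecoder

    # cv2.imread returns the image in BGR channel order
    bgr = cv2.imread('cat_wm.png')
    # 32 is the watermark length in bits (a 4-byte watermark);
    # the 17-byte "StableDiffusionV1" watermark would need 136 here
    decoder = WatermarkDecoder('bytes', 32)
    watermark = decoder.decode(bgr, 'dwtDct')
    print(watermark.decode('utf-8'))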
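The second argument of `WatermarkDecoder('bytes', ...)` is the watermark length in bits, so it has to match the watermark that was embedded. A quick sketch of the arithmetic, assuming UTF-8 encoding as in the encoder code above (the 4-byte string `"test"` is only an illustrative example of a watermark that fits the length 32 used in the decoding snippet):

```python
wm = "StableDiffusionV1"
payload = wm.encode("utf-8")
num_bits = len(payload) * 8            # decoder length for this watermark

short_bits = len("test".encode("utf-8")) * 8   # a 4-byte example watermark

print(len(payload))   # 17
print(num_bits)       # 136
print(short_bits)     # 32
```

So to read back the "StableDiffusionV1" watermark, the decoder length would be 136 rather than 32.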


    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))
    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1
    start_code = None
    if opt.fixed_code:
        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
    precision_scope = autocast if opt.precision=="autocast" else nullcontext
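Two details of this block are easy to check by hand: how prompts read from a file are grouped into batches, and what shape the fixed starting code has. The `chunk` helper is called above but not shown in this post, so the definition below is an assumed implementation based on the `islice` import; the prompt list is made up, and the numbers are the argument defaults.

```python
from itertools import islice

def chunk(it, size):
    """Split an iterable into tuples of at most `size` items (assumed helper)."""
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

prompts = ["a cat", "a dog", "a horse", "a bird"]   # made-up file contents
batch_size = 3                                      # default --n_samples
data = list(chunk(prompts, batch_size))
print(data)   # [('a cat', 'a dog', 'a horse'), ('a bird',)]

# Shape of start_code with the default arguments
n_samples, C, H, W, f = 3, 4, 512, 512, 8
shape = [n_samples, C, H // f, W // f]
print(shape)  # [3, 4, 64, 64]
```

The last batch may be shorter than `batch_size`, and the starting code lives in the 64x64 latent space rather than in the 512x512 pixel space.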


class torch.autocast(device_type, enabled=True, **kwargs)


with precision_scope("cuda"):
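One reason the same `precision_scope("cuda")` call works for both precision settings is that `nullcontext` also accepts one positional argument (its `enter_result`) and simply does nothing. The sketch below shows this without needing PyTorch; `pick_scope` and the `autocast_impl` parameter are hypothetical names used only so the example stays self-contained.

```python
from contextlib import nullcontext

def pick_scope(precision, autocast_impl):
    """Mirror of: autocast if opt.precision == "autocast" else nullcontext."""
    return autocast_impl if precision == "autocast" else nullcontext

# "full" precision: the scope is nullcontext, a do-nothing context manager.
scope = pick_scope("full", autocast_impl=None)
with scope("cuda") as value:
    # nullcontext returns the argument it was given; nothing else happens
    print(value)   # cuda
```

With `--precision autocast` the same `with` line would instead enter `torch.autocast` for the given device type.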



Positive prompt:
{{alice}}, alice in wonderland, {{{solo}}},1girl,{{delicate face}},vely long hair,blunt_bangs,{{{full body}}},{floating hair}, {looking_at_viewer},open mouth,{looking_at_viewer},open mouth,blue eyes,Blonde_hair,Beautiful eyes,gradient hair,{{white_frilled_dress}},{{white pantyhose}}, {long sleeves},{juliet_sleeves},{puffy sleeves},white hair bow, Skirt pleats, blue dress bow, blue_large_bow,{{{stading}}},{{{arms behind back}}},sleeves past wrists,sleeves past fingers,{forest}, flowering hedge, scenery,Flowery meadow,clear sky,{delicate grassland},{blooming white roses},flying butterfly,shadow,beautiful sky,cumulonimbus,{{absurdres}},incredibly_absurdres, huge_filesize, {best quality},{masterpiece},delicate details,refined rendering,original,official_art, 10s,

Negative prompt:
lowres,highres, worst quality,low quality,normal quality,artbook, game_cg, duplicate,grossproportions,deformed,out of frame,60s,70s,80s,90s,00s, ugly,morbid,mutation,death, kaijuu,mutation,no hunmans.monster girl,arthropod girl,arthropod limbs,tentacles,blood,size difference,sketch,blurry,blurry face,blurry background,blurry foreground, disfigured,extra,extra_arms,extra_ears,extra_breasts,extra_legs,extra_penises,extra_mouth,multiple_arms,multiple_legs,mutilated,tranny,trans,trannsexual,out of frame,poorly drawnhands,extra fingers,mutated hands, poorly drawn face, bad anatomy,bad proportions, extralimbs,more than 2 nipples,extra limbs,bad anatomy,malformed limbs,missing arms,miss finglegs,mutated hands,fused fingers,too many fingers,long neck,bad finglegs,cropped, bad feet,bad anatomy disfigured,malformed mutated,missing limb,malformed hands,

Steps: 50, Sampler: DDIM, CFG scale: 7, Size: 1024x1024

