当前位置:   article > 正文

Stable Diffusion如何实现API切换模型_stable diffusion 脚本切换模型源码

stable diffusion 脚本切换模型源码

研究过 Stable Diffusion 接口文档的小伙伴们肯定知道,文档中的 txt2img 接口并没有提供用于指定模型的参数,那么如何通过 API 切换模型呢?




def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs)




  1. def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
  2. self.outpath_samples: str = outpath_samples # 生成的图片的保存路径,和下面的do_not_save_samples配合运用
  3. self.outpath_grids: str = outpath_grids
  4. self.prompt: str = prompt # 正向提示词
  5. self.prompt_for_display: str = None
  6. self.negative_prompt: str = (negative_prompt or "") # 反向提示词
  7. self.styles: list = styles or []
  8. self.seed: int = seed # 种子,-1表明运用随机种子
  9. self.sampler_name: str = sampler_name # 采样方法,比方"DPM++ SDE Karras"
  10. self.batch_size: int = batch_size # 每批生成的数量?
  11. self.n_iter: int = n_iter
  12. self.steps: int = steps # UI中的sampling steps
  13. self.cfg_scale: float = cfg_scale # UI中的CFG Scale,提示词相关性
  14. self.width: int = width # 生成图像的宽度
  15. self.height: int = height # 生成图像的高度
  16. self.restore_faces: bool = restore_faces # 是否运用面部修正
  17. self.tiling: bool = tiling # 是否运用可平铺(tilling)
  18. self.do_not_save_samples: bool = do_not_save_samples




  1. def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):  # Original txt2img API handler (excerpt; "......" marks omitted code)
  2. ......
  3. with self.queue_lock:  # presumably serializes generation requests on queue_lock
  4. p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)  # the model is taken from the global shared.sd_model, not from the request
  5. ......
  6. return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())  # echoes the request fields back as "parameters"

从代码中可以看出,生成时使用的模型取自 shared.sd_model。这个模型是全局共享的,而不是按用户(按请求)区分的。因此,当 API 请求传入的模型与当前已加载的模型不一致时,就需要重新加载模型:直接调用 modules/sd_models.py 中的 reload_model_weights(sd_model=None, info=None) 函数即可。我们只需传入 info 参数,用它来指定想要加载的模型。该函数内部会自动判断要加载的模型与当前模型是否相同,相同的话就不会重复加载。


  1. class CheckpointInfo:  # Describes one checkpoint file and the identifiers it can be looked up by (excerpt from modules/sd_models.py)
  2. def __init__(self, filename):  # filename: path of the checkpoint file
  3. self.filename = filename
  4. abspath = os.path.abspath(filename)
  5. if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):  # file lives under the configured ckpt_dir
  6. name = abspath.replace(shared.cmd_opts.ckpt_dir, '')  # NOTE(review): replace() removes every occurrence of the dir string, not just the leading prefix
  7. elif abspath.startswith(model_path):  # otherwise, file lives under model_path
  8. name = abspath.replace(model_path, '')  # name relative to model_path
  9. else:
  10. name = os.path.basename(filename)  # outside both dirs: use the bare file name
  11. if name.startswith("\\") or name.startswith("/"):  # strip a leading separator left over from the prefix removal
  12. name = name[1:]
  13. self.name = name
  14. self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]  # file name without directory or extension
  15. self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]  # relative name with path separators turned into '_' and extension dropped
  16. self.hash = model_hash(filename)  # hash computed by model_hash() from the file
  17. self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)  # SHA-256 looked up via the hash cache; may be falsy (presumably when not cached)
  18. self.shorthash = self.sha256[0:10] if self.sha256 else None  # first 10 characters of the SHA-256, or None
  19. self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'  # display title: name, plus " [shorthash]" when available
  20. self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])  # every identifier this checkpoint can be referred to by


  1. from modules import sd_models  # Minimal example: load a specific checkpoint by its path
  2. checkpoint_info = sd_models.CheckpointInfo("模型的全路径名称")  # replace the placeholder with the checkpoint file's full path
  3. sd_models.reload_model_weights(info=checkpoint_info)  # reportedly a no-op when this checkpoint is already loaded — confirm in sd_models.reload_model_weights


1. 修改 modules/api/models.py 中的 StableDiffusionTxt2ImgProcessingAPI,增加模型名称字段 model_name:

  1. StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(  # request model for the txt2img API, built from StableDiffusionProcessingTxt2Img plus the extra fields listed below
  2. "StableDiffusionProcessingTxt2Img",
  3. StableDiffusionProcessingTxt2Img,
  4. [
  5. {"key": "sampler_index", "type": str, "default": "Euler"},
  6. {"key": "script_name", "type": str, "default": None},
  7. {"key": "script_args", "type": list, "default": []},
  8. {"key": "send_images", "type": bool, "default": True},
  9. {"key": "save_images", "type": bool, "default": False},
  10. {"key": "alwayson_scripts", "type": dict, "default": {}},
  11. {"key": "model_name", "type": str, "default": None},  # added field: checkpoint file name to switch to; defaults to None, so the handler must validate it
  12. ]
  13. ).generate_model()


def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '',model_name: str=None, **kwargs):


  1. def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):  # Modified txt2img handler: switches to the requested checkpoint before generating ("......" marks omitted code; needs os, models_path, sd_models and HTTPException in scope)
  2. ......
  3. model_name=txt2imgreq.model_name  # checkpoint file name supplied by the client
  4. if model_name is None:  # the request model defaults model_name to None, so it is checked here
  5. raise HTTPException(status_code=404, detail="model_name not found")  # NOTE(review): a missing request parameter is a client error; 400/422 seems more apt than 404
  6. ......
  7. with self.queue_lock:  # the global model is swapped while holding the same lock that guards building the processing object
  8. checkpoint_info = sd_models.CheckpointInfo(os.path.join(models_path,'Stable-diffusion',model_name))  # NOTE(review): model_name is client input — "../x" escapes the Stable-diffusion dir and an absolute path discards the base dir; consider validating against known checkpoints. Also builds a CheckpointInfo (which hashes the file) on every request
  9. sd_models.reload_model_weights(info=checkpoint_info)  # switch the global model to the requested checkpoint
  10. p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)  # shared.sd_model presumably now holds the requested checkpoint
  11. ......



