
Stable Diffusion | Gradio Interface Design and ComfyUI API Calls


Built on the ComfyUI API, this article implements a webUI-style interactive Gradio interface that supports text-to-image/image-to-image (SD1.x, SD2.x, SDXL, Stable Cascade), Lora, ControlNet, image-to-video (SVD_xt_1_1), detail restoration (FaceDetailer), image upscaling (Extras), and reading workflow info from images/API files (Info).
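Every tab below drives ComfyUI the same way: the UI assembles an API-format workflow (a JSON graph of nodes), POSTs it to the server's /prompt endpoint, waits for execution to finish, then pulls the saved images back out of /history and /view. Here is a minimal sketch of that round trip, assuming a ComfyUI server on 127.0.0.1:8188 and an exported workflow_api.json (the file name and the polling loop are illustrative; the full program below listens on the /ws websocket instead, which also yields per-node progress):

# Minimal ComfyUI API round trip (sketch): queue a workflow, wait, download.
import json
import time
import uuid

import requests

server_address = "127.0.0.1:8188"
client_id = str(uuid.uuid4())

with open("workflow_api.json") as f:  # hypothetical exported API workflow
    workflow = json.load(f)

data = {"prompt": workflow, "client_id": client_id}
prompt_id = requests.post(f"http://{server_address}/prompt", json=data).json()["prompt_id"]

# Poll /history until the prompt appears there, i.e. execution has finished.
while True:
    history = requests.get(f"http://{server_address}/history/{prompt_id}").json()
    if prompt_id in history:
        break
    time.sleep(0.5)

# Every SaveImage node lists its files under "outputs"; fetch them via /view.
for node_output in history[prompt_id]["outputs"].values():
    for image in node_output.get("images", []):
        params = {"filename": image["filename"], "subfolder": image["subfolder"], "type": image["type"]}
        image_bytes = requests.get(f"http://{server_address}/view", params=params).content
        with open(image["filename"], "wb") as out:
            out.write(image_bytes)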

1. Online Demo

The code from this article has been deployed to the Baidu PaddlePaddle AI Studio platform, so you can try the original Stable Diffusion ComfyUI/webUI interfaces as well as the custom Gradio interface online.

Project link: Stable Diffusion ComfyUI Online Demo

2. Custom Gradio Interface Showcase

Stable Diffusion text-to-image / image-to-image interface:

Stable Cascade text-to-image / image-to-image interface:

Stable Video Diffusion image-to-video interface:

Image upscaling interface:

Image / API workflow info interface:

3. Gradio Interface Design and ComfyUI API Source Code

import io
import json
import os
import random
import re
import subprocess
import sys
import urllib.parse
import uuid
sys.path.append("/home/aistudio/ComfyUI/venv/lib/python3.10/site-packages")
import gradio as gr
import requests
import websocket
from collections import OrderedDict, deque
from PIL import Image
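
# Global defaults. Setting design_mode = 1 shrinks the image size and step
# count so the UI layout can be debugged quickly without a full GPU run.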
class Default:
    # 1 = enabled, 0 = disabled
    design_mode = 0
    lora_weight = 0.8
    controlnet_num = 5
    controlnet_saveimage = 1
    facedetailer_num = 3
    prompt = "(best quality:1), (high quality:1), detailed/(extreme, highly, ultra/), realistic, 1girl/(beautiful, delicate, perfect/), "
    negative_prompt = "(worst quality:1), (low quality:1), (normal quality:1), lowres, signature, blurry, watermark, duplicate, bad link, plump, bad anatomy, extra arms, extra digits, missing finger, bad hands, bad feet, deformed, error, mutation, text"
    if design_mode == 1:
        width = 64
        height = 64
        steps = 2
    else:
        width = 512
        height = 768
        steps = 20
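
# Session state: a per-session client id, the ComfyUI server address (looked
# up from the process list when not in design mode), and an upload cache.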
class Initial:
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
    client_id = str(uuid.uuid4())
    server_address = "127.0.0.1:8188"
    if Default.design_mode == 0:
        cmd = "ps -eo pid,args | grep 'export GRADIO_SERVER_PORT=' | awk '{print $8, $14}'"
        ps_output = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout.splitlines()
        for i in ps_output:
            if "/home/aistudio/ComfyUI.gradio.py" in i:
                port = i.split(" ")[0].split("=")[1]
                server_address = f"127.0.0.1:{port}"
    output_dir = os.path.join(os.getcwd(), "ComfyUI/output/")
    uploaded_image = {}
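
# Dropdown choices pulled from the ComfyUI /object_info and /embeddings
# endpoints at startup: checkpoints, ControlNet models, samplers, VAEs, etc.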
class Choices:
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(Initial.server_address, Initial.client_id))
    object_info = requests.get(url="http://{}/object_info".format(Initial.server_address)).json()
    embedding = requests.get(url="http://{}/embeddings".format(Initial.server_address)).json()
    ws.close()
    ckpt = []
    ckpt_list = {}
    ckpt_name = object_info["ImageOnlyCheckpointLoader"]["input"]["required"]["ckpt_name"][0]
    hidden_ckpt = ["stable_cascade_stage_c.safetensors", "stable_cascade_stage_b.safetensors", "svd_xt_1_1.safetensors", "control_v11p_sd15_canny.safetensors", "control_v11f1p_sd15_depth.safetensors", "control_v11p_sd15_openpose.safetensors"]
    for i in ckpt_name:
        path, file = os.path.split(i)
        if file not in hidden_ckpt:
            ckpt.append(file)
        ckpt_list[file] = i
    ckpt = sorted(ckpt)
    controlnet_model = []
    controlnet_model_list = {}
    controlnet_name = object_info["ControlNetLoader"]["input"]["required"]["control_net_name"][0]
    for i in controlnet_name:
        path, file = os.path.split(i)
        controlnet_model.append(file)
        controlnet_model_list[file] = i
    controlnet_model = sorted(controlnet_model)
    preprocessor = ["Canny"]
    if "AIO_Preprocessor" in object_info:
        preprocessor = ["none", "Canny", "CannyEdgePreprocessor", "DepthAnythingPreprocessor", "DWPreprocessor", "OpenposePreprocessor"]
        for i in sorted(object_info["AIO_Preprocessor"]["input"]["optional"]["preprocessor"][0]):
            if i not in preprocessor:
                preprocessor.append(i)
    if "FaceDetailer" in object_info:
        facedetailer_detector_model = []
        facedetailer_detector_model_list = {}
        facedetailer_detector_model_name = object_info["UltralyticsDetectorProvider"]["input"]["required"]["model_name"][0]
        for i in facedetailer_detector_model_name:
            path, file = os.path.split(i)
            facedetailer_detector_model.append(file)
            facedetailer_detector_model_list[file] = i
        facedetailer_detector_model = sorted(facedetailer_detector_model)
    lora = object_info["LoraLoader"]["input"]["required"]["lora_name"][0]
    sampler = object_info["KSampler"]["input"]["required"]["sampler_name"][0]
    scheduler = object_info["KSampler"]["input"]["required"]["scheduler"][0]
    upscale_method = object_info["ImageScaleBy"]["input"]["required"]["upscale_method"][0]
    upscale_model = object_info["UpscaleModelLoader"]["input"]["required"]["model_name"][0]
    vae = ["Automatic"]
    for i in sorted(object_info["VAELoader"]["input"]["required"]["vae_name"][0]):
        vae.append(i)
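
# Shared helpers: prompt cleanup, seed handling, image upload, workflow
# re-ordering (topological sort plus renumbering), job submission with
# progress reporting over the websocket, and gallery utilities.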
class Function:
    def format_prompt(prompt):
        prompt = re.sub(r"\s+,", ",", prompt)
        prompt = re.sub(r"\s+", " ", prompt)
        prompt = re.sub(",,+", ",", prompt)
        prompt = re.sub(",", ", ", prompt)
        prompt = re.sub(r"\s+", " ", prompt)
        prompt = re.sub(r"^,", "", prompt)
        prompt = re.sub(r"^ ", "", prompt)
        prompt = re.sub(r" $", "", prompt)
        prompt = re.sub(r",$", "", prompt)
        prompt = re.sub(": ", ":", prompt)
        return prompt

    def get_model_path(model_name):
        return Choices.ckpt_list[model_name]

    def gen_seed(seed):
        seed = int(seed)
        if seed < 0:
            seed = random.randint(0, 18446744073709551615)
        if seed > 18446744073709551615:
            seed = 18446744073709551615
        return seed

    def initialize():
        Lora.cache = {}
        Upscale.cache = {}
        UpscaleWithModel.cache = {}
        ControlNet.cache = {}
        FaceDetailer.cache = {}

    def upload_image(image):
        buffer = io.BytesIO()
        image.save(buffer, format="png")
        image = buffer.getbuffer()
        image_hash = hash(image.tobytes())
        if image_hash in Initial.uploaded_image:
            return Initial.uploaded_image[image_hash]
        image_name = str(uuid.uuid4()) + ".png"
        Initial.uploaded_image[image_hash] = image_name
        image_file = {"image": (image_name, image)}
        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(Initial.server_address, Initial.client_id))
        requests.post(url="http://{}/upload/image".format(Initial.server_address), files=image_file)
        ws.close()
        return image_name

    def order_workflow(workflow):
        link_list = {}
        for node in workflow:
            node_link = []
            for input in workflow[node]["inputs"]:
                if isinstance(workflow[node]["inputs"][input], list):
                    node_link.append(workflow[node]["inputs"][input][0])
            link_list[node] = node_link
        in_degree = {v: 0 for v in link_list}
        for node in link_list:
            for neighbor in link_list[node]:
                in_degree[neighbor] += 1
        queue = deque([node for node in in_degree if in_degree[node] == 0])
        order_list = []
        while queue:
            node = queue.popleft()
            order_list.append(node)
            for neighbor in link_list[node]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)
        order_list = order_list[::-1]
        max_nodes = 1000
        new_node_id = max_nodes * 10 + 1
        workflow_string = json.dumps(workflow)
        for node in order_list:
            workflow_string = workflow_string.replace(f'"{node}"', f'"{new_node_id}"')
            new_node_id += 1
        workflow = json.loads(workflow_string)
        workflow = OrderedDict(sorted(workflow.items()))
        new_node_id = 1
        workflow_string = json.dumps(workflow)
        for node in workflow:
            workflow_string = workflow_string.replace(f'"{node}"', f'"{new_node_id}"')
            new_node_id += 1
        workflow = json.loads(workflow_string)
        for node in workflow:
            if "_meta" in workflow[node]:
                del workflow[node]["_meta"]
        return workflow

    def post_interrupt():
        Initial.interrupt = True
        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(Initial.server_address, Initial.client_id))
        requests.post(url="http://{}/interrupt".format(Initial.server_address))
        ws.close()

    def add_embedding(embedding, negative_prompt):
        for i in Choices.embedding:
            negative_prompt = negative_prompt.replace(f"embedding:{i},", "")
        negative_prompt = Function.format_prompt(negative_prompt)
        for i in embedding[::-1]:
            negative_prompt = f"embedding:{i}, {negative_prompt}"
        return negative_prompt

    def gen_image(workflow, counter, batch_count, progress):
        if counter == 1:
            progress(0, desc="Processing...")
        if batch_count == 1:
            batch_info = ""
        else:
            batch_info = f"Batch {counter}/{batch_count}: "
        workflow = Function.order_workflow(workflow)
        current_progress = 0
        Initial.interrupt = False
        ws = websocket.WebSocket()
        ws.connect("ws://{}/ws?clientId={}".format(Initial.server_address, Initial.client_id))
        data = {"prompt": workflow, "client_id": Initial.client_id}
        prompt_id = requests.post(url="http://{}/prompt".format(Initial.server_address), json=data).json()["prompt_id"]
        while True:
            try:
                ws.settimeout(0.1)
                wsrecv = ws.recv()
                if isinstance(wsrecv, str):
                    data = json.loads(wsrecv)["data"]
                    if "node" in data:
                        if data["node"] is not None:
                            if "value" in data and "max" in data:
                                if data["max"] > 1:
                                    current_progress = data["value"] / data["max"]
                                    progress(current_progress, desc=f"{batch_info}" + workflow[data["node"]]["class_type"] + " " + str(data["value"]) + "/" + str(data["max"]))
                            else:
                                progress(current_progress, desc=f"{batch_info}" + workflow[data["node"]]["class_type"])
                        if data["node"] is None and data["prompt_id"] == prompt_id:
                            break
                else:
                    continue
            except websocket.WebSocketTimeoutException:
                if Initial.interrupt is True:
                    ws.close()
                    return None, None
        history = requests.get(url="http://{}/history/{}".format(Initial.server_address, prompt_id)).json()[prompt_id]
        images = []
        file_path = ""
        for node_id in history["outputs"]:
            node_output = history["outputs"][node_id]
            if "images" in node_output:
                for image in node_output["images"]:
                    file_path = Initial.output_dir + image["filename"]
                    data = {"filename": image["filename"], "subfolder": image["subfolder"], "type": image["type"]}
                    url_values = urllib.parse.urlencode(data)
                    image_data = requests.get("http://{}/view?{}".format(Initial.server_address, url_values)).content
                    image = Image.open(io.BytesIO(image_data))
                    images.append(image)
        ws.close()
        return images, file_path

    def get_gallery_index(evt: gr.SelectData):
        return evt.index

    def get_image_info(image_pil):
        image_info = []
        if image_pil is None:
            return
        for key, value in image_pil.info.items():
            image_info.append(value)
        if image_info != []:
            image_info = image_info[0]
            if image_info == 0:
                image_info = "None"
        else:
            image_info = "None"
        return image_info

    def send_to(data, index):
        if data == [] or data is None:
            return None
        return data[index]
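
# Lora stack: each selected Lora becomes a LoraLoader node chained between
# the checkpoint's MODEL/CLIP outputs and the text encoders.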
class Lora:
    cache = {}

    def add_node(module, workflow, node_id, model_port, clip_port):
        for lora in Lora.cache[module]:
            strength_model = Lora.cache[module][lora]
            strength_clip = Lora.cache[module][lora]
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"lora_name": lora, "strength_model": strength_model, "strength_clip": strength_clip, "model": model_port, "clip": clip_port}, "class_type": "LoraLoader"}
            model_port = [str(node_id), 0]
            clip_port = [str(node_id), 1]
        return workflow, node_id, model_port, clip_port

    def update_cache(module, lora, lora_weight):
        if Initial.initialized is False:
            Function.initialize()
        if lora == []:
            Lora.cache[module] = {}
            return True, [], gr.update(value="", visible=False)
        lora_list = {}
        for i in lora_weight.split("<"):
            for j in i.split(">"):
                if j != "" and ":" in j:
                    lora_name, weight = j.split(":")
                    lora_list[lora_name] = weight
        lora_weight = ""
        Lora.cache[module] = {}
        for i in lora:
            if i in lora_list:
                weight = lora_list[i]
            else:
                weight = Default.lora_weight
            if lora.index(i) == 0:
                lora_weight = f"<{i}:{weight}>"
            else:
                lora_weight = f"{lora_weight}\n\n<{i}:{weight}>"
            if weight != "":
                weight = float(weight)
            Lora.cache[module][i] = weight
        return True, gr.update(), gr.update(value=lora_weight, visible=True)

    def blocks(module):
        module = gr.Textbox(value=module, visible=False)
        lora = gr.Dropdown(Choices.lora, label="Lora", multiselect=True, interactive=True)
        lora_weight = gr.Textbox(label="Lora weight | Lora 权重", visible=False)
        for gr_block in [lora, lora_weight]:
            gr_block.change(fn=Lora.update_cache, inputs=[module, lora, lora_weight], outputs=[Initial.initialized, lora, lora_weight])
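
# Arithmetic upscaling via an ImageScaleBy node; auto-enabled once the
# scale factor exceeds 1.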
class Upscale:
    cache = {}

    def add_node(module, workflow, node_id, image_port):
        upscale_method = Upscale.cache[module]["upscale_method"]
        scale_by = Upscale.cache[module]["scale_by"]
        node_id += 1
        workflow[str(node_id)] = {"inputs": {"upscale_method": upscale_method, "scale_by": scale_by, "image": image_port}, "class_type": "ImageScaleBy"}
        image_port = [str(node_id), 0]
        return workflow, node_id, image_port

    def auto_enable(scale_by):
        if scale_by > 1:
            return True
        else:
            return False

    def update_cache(module, enable, upscale_method, scale_by):
        if Initial.initialized is False:
            Function.initialize()
        if module not in Upscale.cache:
            Upscale.cache[module] = {}
        if enable is True:
            Upscale.cache[module]["upscale_method"] = upscale_method
            Upscale.cache[module]["scale_by"] = scale_by
        else:
            del Upscale.cache[module]
        return True

    def blocks(module):
        module = gr.Textbox(value=module, visible=False)
        enable = gr.Checkbox(label="Enable(放大系数大于1后自动启用)")
        with gr.Row():
            upscale_method = gr.Dropdown(Choices.upscale_method, label="Upscale method | 放大方法", value=Choices.upscale_method[-1])
            scale_by = gr.Slider(minimum=1, maximum=8, step=1, label="Scale by | 放大系数", value=1)
        scale_by.release(fn=Upscale.auto_enable, inputs=[scale_by], outputs=[enable])
        inputs = [module, enable, upscale_method, scale_by]
        for gr_block in inputs:
            if type(gr_block) is gr.components.slider.Slider:
                gr_block.release(fn=Upscale.update_cache, inputs=inputs, outputs=[Initial.initialized])
            else:
                gr_block.change(fn=Upscale.update_cache, inputs=inputs, outputs=[Initial.initialized])
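
# Model-based (super-resolution) upscaling: an UpscaleModelLoader feeding an
# ImageUpscaleWithModel node.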
class UpscaleWithModel:
    cache = {}

    def add_node(module, workflow, node_id, image_port):
        upscale_model = UpscaleWithModel.cache[module]["upscale_model"]
        node_id += 1
        workflow[str(node_id)] = {"inputs": {"model_name": upscale_model}, "class_type": "UpscaleModelLoader"}
        upscale_model_port = [str(node_id), 0]
        node_id += 1
        workflow[str(node_id)] = {"inputs": {"upscale_model": upscale_model_port, "image": image_port}, "class_type": "ImageUpscaleWithModel"}
        image_port = [str(node_id), 0]
        return workflow, node_id, image_port

    def update_cache(module, enable, upscale_model):
        if Initial.initialized is False:
            Function.initialize()
        if module not in UpscaleWithModel.cache:
            UpscaleWithModel.cache[module] = {}
        if enable is True:
            UpscaleWithModel.cache[module]["upscale_model"] = upscale_model
        else:
            del UpscaleWithModel.cache[module]
        return True

    def blocks(module):
        module = gr.Textbox(value=module, visible=False)
        enable = gr.Checkbox(label="Enable")
        upscale_model = gr.Dropdown(Choices.upscale_model, label="Upscale model | 超分模型", value=Choices.upscale_model[0])
        inputs = [module, enable, upscale_model]
        for gr_block in inputs:
            gr_block.change(fn=UpscaleWithModel.update_cache, inputs=inputs, outputs=[Initial.initialized])
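
# ControlNet units: each enabled unit appends LoadImage -> preprocessor ->
# ControlNetLoader -> ControlNetApplyAdvanced nodes, rewiring the positive
# and negative conditioning ports.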
class ControlNet:
    cache = {}
    model_preprocessor_list = {
        "control_v11e_sd15_ip2p.safetensors": [],
        "control_v11e_sd15_shuffle.safetensors": ["ShufflePreprocessor"],
        "control_v11f1e_sd15_tile.bin": ["TilePreprocessor", "TTPlanet_TileGF_Preprocessor", "TTPlanet_TileSimple_Preprocessor"],
        "control_v11f1p_sd15_depth.safetensors": ["DepthAnythingPreprocessor", "LeReS-DepthMapPreprocessor", "MiDaS-NormalMapPreprocessor", "MeshGraphormer-DepthMapPreprocessor", "MeshGraphormer+ImpactDetector-DepthMapPreprocessor", "MiDaS-DepthMapPreprocessor", "Zoe_DepthAnythingPreprocessor", "Zoe-DepthMapPreprocessor"],
        "control_v11p_sd15_canny.safetensors": ["Canny", "CannyEdgePreprocessor"],
        "control_v11p_sd15_inpaint.safetensors": [],
        "control_v11p_sd15_lineart.safetensors": ["LineArtPreprocessor", "LineartStandardPreprocessor"],
        "control_v11p_sd15_mlsd.safetensors": ["M-LSDPreprocessor"],
        "control_v11p_sd15_normalbae.safetensors": ["BAE-NormalMapPreprocessor", "DSINE-NormalMapPreprocessor"],
        "control_v11p_sd15_openpose.safetensors": ["DWPreprocessor", "OpenposePreprocessor", "DensePosePreprocessor"],
        "control_v11p_sd15_scribble.safetensors": ["ScribblePreprocessor", "Scribble_XDoG_Preprocessor", "Scribble_PiDiNet_Preprocessor", "FakeScribblePreprocessor"],
        "control_v11p_sd15_seg.safetensors": ["AnimeFace_SemSegPreprocessor", "OneFormer-COCO-SemSegPreprocessor", "OneFormer-ADE20K-SemSegPreprocessor", "SemSegPreprocessor", "UniFormer-SemSegPreprocessor"],
        "control_v11p_sd15_softedge.safetensors": ["HEDPreprocessor", "PiDiNetPreprocessor", "TEEDPreprocessor", "DiffusionEdge_Preprocessor"],
        "control_v11p_sd15s2_lineart_anime.safetensors": ["AnimeLineArtPreprocessor", "Manga2Anime_LineArt_Preprocessor"],
        "control_scribble.safetensors": ["BinaryPreprocessor"],
        "ioclab_sd15_recolor.safetensors": ["ImageLuminanceDetector", "ImageIntensityDetector"],
        "control_sd15_animal_openpose_fp16.pth": ["AnimalPosePreprocessor"],
        "controlnet_sd21_laion_face_v2.safetensors": ["MediaPipe-FaceMeshPreprocessor"]
    }

    def add_node(module, counter, workflow, node_id, positive_port, negative_port):
        for unit_id in ControlNet.cache[module]:
            preprocessor = ControlNet.cache[module][unit_id]["preprocessor"]
            model = ControlNet.cache[module][unit_id]["model"]
            input_image = ControlNet.cache[module][unit_id]["input_image"]
            resolution = ControlNet.cache[module][unit_id]["resolution"]
            strength = ControlNet.cache[module][unit_id]["strength"]
            start_percent = ControlNet.cache[module][unit_id]["start_percent"]
            end_percent = ControlNet.cache[module][unit_id]["end_percent"]
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"image": input_image, "upload": "image"}, "class_type": "LoadImage"}
            image_port = [str(node_id), 0]
            if preprocessor == "Canny":
                node_id += 1
                workflow[str(node_id)] = {"inputs": {"low_threshold": 0.3, "high_threshold": 0.7, "image": image_port}, "class_type": "Canny"}
                image_port = [str(node_id), 0]
            else:
                node_id += 1
                workflow[str(node_id)] = {"inputs": {"preprocessor": preprocessor, "resolution": resolution, "image": image_port}, "class_type": "AIO_Preprocessor"}
                image_port = [str(node_id), 0]
            if counter == 1 and Default.controlnet_saveimage == 1:
                node_id += 1
                workflow[str(node_id)] = {"inputs": {"filename_prefix": "ControlNet", "images": image_port}, "class_type": "SaveImage"}
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"control_net_name": model}, "class_type": "ControlNetLoader"}
            control_net_port = [str(node_id), 0]
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"strength": strength, "start_percent": start_percent, "end_percent": end_percent, "positive": positive_port, "negative": negative_port, "control_net": control_net_port, "image": image_port}, "class_type": "ControlNetApplyAdvanced"}
            positive_port = [str(node_id), 0]
            negative_port = [str(node_id), 1]
        return workflow, node_id, positive_port, negative_port

    def auto_enable():
        return True

    def auto_select_model(preprocessor):
        for model in Choices.controlnet_model:
            if model in ControlNet.model_preprocessor_list:
                if preprocessor in ControlNet.model_preprocessor_list[model]:
                    return gr.update(value=model)
        return gr.update(value="未定义/检测到对应的模型,请自行选择!")

    def preprocess(unit_id, preview, preprocessor, input_image, resolution, progress=gr.Progress()):
        if preview is False or input_image is None:
            return
        input_image = Function.upload_image(input_image)
        workflow = {}
        node_id = 1
        workflow[str(node_id)] = {"inputs": {"image": input_image, "upload": "image"}, "class_type": "LoadImage"}
        image_port = [str(node_id), 0]
        if preprocessor == "Canny":
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"low_threshold": 0.3, "high_threshold": 0.7, "image": image_port}, "class_type": "Canny"}
            image_port = [str(node_id), 0]
        else:
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"preprocessor": preprocessor, "resolution": resolution, "image": image_port}, "class_type": "AIO_Preprocessor"}
            image_port = [str(node_id), 0]
        node_id += 1
        workflow[str(node_id)] = {"inputs": {"images": image_port}, "class_type": "PreviewImage"}
        output = Function.gen_image(workflow, 1, 1, progress)[0]
        if output is not None:
            output = output[0]
        return output

    def update_cache(module, unit_id, enable, preprocessor, model, input_image, resolution, strength, start_percent, end_percent):
        if Initial.initialized is False:
            Function.initialize()
        if module not in ControlNet.cache:
            ControlNet.cache[module] = {}
        ControlNet.cache[module][unit_id] = {}
        if input_image is None:
            del ControlNet.cache[module][unit_id]
            return True, False
        if model not in Choices.controlnet_model:
            del ControlNet.cache[module][unit_id]
            return True, False
        if enable is True:
            ControlNet.cache[module][unit_id]["preprocessor"] = preprocessor
            ControlNet.cache[module][unit_id]["model"] = Choices.controlnet_model_list[model]
            ControlNet.cache[module][unit_id]["input_image"] = Function.upload_image(input_image)
            ControlNet.cache[module][unit_id]["resolution"] = resolution
            ControlNet.cache[module][unit_id]["strength"] = strength
            ControlNet.cache[module][unit_id]["start_percent"] = start_percent
            ControlNet.cache[module][unit_id]["end_percent"] = end_percent
        else:
            del ControlNet.cache[module][unit_id]
        return True, gr.update()

    def unit(module, i):
        module = gr.Textbox(value=module, visible=False)
        unit_id = gr.Textbox(value=i, visible=False)
        with gr.Row():
            enable = gr.Checkbox(label="Enable(上传图片后自动启用)")
            preview = gr.Checkbox(label="Preview")
        with gr.Row():
            preprocessor = gr.Dropdown(Choices.preprocessor, label="Preprocessor", value="Canny")
            model = gr.Dropdown(Choices.controlnet_model, label="ControlNet model", value="control_v11p_sd15_canny.safetensors")
        with gr.Row():
            input_image = gr.Image(type="pil")
            preprocess_preview = gr.Image(label="Preprocessor preview")
        with gr.Row():
            resolution = gr.Slider(label="Resolution", minimum=64, maximum=2048, step=64, value=512)
            strength = gr.Slider(label="Strength", minimum=0, maximum=2, step=0.01, value=1)
        with gr.Row():
            start_percent = gr.Slider(label="Start percent", minimum=0, maximum=1, step=0.01, value=0)
            end_percent = gr.Slider(label="End percent", minimum=0, maximum=1, step=0.01, value=1)
        input_image.upload(fn=ControlNet.auto_enable, inputs=None, outputs=[enable])
        preprocessor.change(fn=ControlNet.auto_select_model, inputs=[preprocessor], outputs=[model])
        for gr_block in [preview, preprocessor, input_image]:
            gr_block.change(fn=ControlNet.preprocess, inputs=[unit_id, preview, preprocessor, input_image, resolution], outputs=[preprocess_preview])
        inputs = [module, unit_id, enable, preprocessor, model, input_image, resolution, strength, start_percent, end_percent]
        for gr_block in inputs:
            if type(gr_block) is gr.components.slider.Slider:
                gr_block.release(fn=ControlNet.update_cache, inputs=inputs, outputs=[Initial.initialized, enable])
            else:
                gr_block.change(fn=ControlNet.update_cache, inputs=inputs, outputs=[Initial.initialized, enable])

    def blocks(module):
        with gr.Tab(label="控制网络"):
            if Default.controlnet_num == 1:
                ControlNet.unit(module, 1)
            else:
                for i in range(Default.controlnet_num):
                    with gr.Tab(label=f"ControlNet Unit {i + 1}"):
                        ControlNet.unit(module, i + 1)
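
# FaceDetailer (Impact Pack) post-processing: an Ultralytics detector plus a
# SAM model drive an inpainting pass over detected faces/hands/persons.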
class FaceDetailer:
    cache = {}

    def add_node(module, workflow, node_id, image_port, model_port, clip_port, vae_port, positive_port, negative_port, seed, steps, cfg, sampler_name, scheduler):
        for unit_id in FaceDetailer.cache[module]:
            model = Choices.facedetailer_detector_model_list[FaceDetailer.cache[module][unit_id]["model"]]
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"model_name": model}, "class_type": "UltralyticsDetectorProvider"}
            bbox_detector_port = [str(node_id), 0]
            segm_detector_opt_port = [str(node_id), 1]
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"model_name": "sam_vit_b_01ec64.pth", "device_mode": "AUTO"}, "class_type": "SAMLoader"}
            sam_model_opt_port = [str(node_id), 0]
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"guide_size": 384, "guide_size_for": "True", "max_size": 1024, "seed": seed, "steps": steps, "cfg": cfg, "sampler_name": sampler_name, "scheduler": scheduler, "denoise": 0.5, "feather": 5, "noise_mask": "True", "force_inpaint": "True", "bbox_threshold": 0.5, "bbox_dilation": 10, "bbox_crop_factor": 3, "sam_detection_hint": "center-1", "sam_dilation": 0, "sam_threshold": 0.93, "sam_bbox_expansion": 0, "sam_mask_hint_threshold": 0.7, "sam_mask_hint_use_negative": "False", "drop_size": 10, "wildcard": "", "cycle": 1, "inpaint_model": "False", "noise_mask_feather": 20, "image": image_port, "model": model_port, "clip": clip_port, "vae": vae_port, "positive": positive_port, "negative": negative_port, "bbox_detector": bbox_detector_port, "sam_model_opt": sam_model_opt_port, "segm_detector_opt": segm_detector_opt_port}, "class_type": "FaceDetailer"}
            image_port = [str(node_id), 0]
        return workflow, node_id, image_port

    def update_cache(module, unit_id, enable, model):
        if Initial.initialized is False:
            Function.initialize()
        if module not in FaceDetailer.cache:
            FaceDetailer.cache[module] = {}
        FaceDetailer.cache[module][unit_id] = {}
        if enable is True:
            FaceDetailer.cache[module][unit_id]["model"] = model
        else:
            del FaceDetailer.cache[module][unit_id]
        return True

    def unit(module, i):
        module = gr.Textbox(value=module, visible=False)
        unit_id = gr.Textbox(value=i, visible=False)
        enable = gr.Checkbox(label="Enable")
        if i == 1:
            model = gr.Dropdown(Choices.facedetailer_detector_model, label="Detector model", value="face_yolov8m.pt")
        if i == 2:
            model = gr.Dropdown(Choices.facedetailer_detector_model, label="Detector model", value="hand_yolov8s.pt")
        if i == 3:
            model = gr.Dropdown(Choices.facedetailer_detector_model, label="Detector model", value="person_yolov8m-seg.pt")
        inputs = [module, unit_id, enable, model]
        for gr_block in inputs:
            gr_block.change(fn=FaceDetailer.update_cache, inputs=inputs, outputs=[Initial.initialized])

    def blocks(module):
        with gr.Tab(label="图像修复"):
            if Default.facedetailer_num == 1:
                FaceDetailer.unit(module, 1)
            else:
                with gr.Row():
                    for i in range(Default.facedetailer_num):
                        with gr.Column():
                            with gr.Tab(label=f"FaceDetailer Unit {i + 1}"):
                                FaceDetailer.unit(module, i + 1)
                    if Default.facedetailer_num % 2 != 0:
                        with gr.Column():
                            gr.HTML("")
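
# Dispatches the enabled post-processing steps (FaceDetailer, then arithmetic
# upscale, then model upscale) onto the tail of a workflow.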
class Postprocess:
    def add_node(module, *args):
        if module == "SD":
            workflow, node_id, image_port, model_port, clip_port, vae_port, positive_port, negative_port, seed, steps, cfg, sampler_name, scheduler = args
        else:
            workflow, node_id, image_port = args
        if module in FaceDetailer.cache:
            workflow, node_id, image_port = FaceDetailer.add_node(module, workflow, node_id, image_port, model_port, clip_port, vae_port, positive_port, negative_port, seed, steps, cfg, sampler_name, scheduler)
        if module in Upscale.cache:
            workflow, node_id, image_port = Upscale.add_node(module, workflow, node_id, image_port)
        if module in UpscaleWithModel.cache:
            workflow, node_id, image_port = UpscaleWithModel.add_node(module, workflow, node_id, image_port)
        return workflow, node_id, image_port

    def blocks(module):
        if module == "SD":
            if "FaceDetailer" in Choices.object_info:
                FaceDetailer.blocks(module)
        with gr.Tab(label="图像放大"):
            with gr.Row():
                with gr.Tab(label="算术放大"):
                    Upscale.blocks(module)
            with gr.Row():
                with gr.Tab(label="超分放大"):
                    UpscaleWithModel.blocks(module)
            gr.HTML("注意:同时启用两种放大模式将先执行算术放大,再执行超分放大,最终放大倍数为二者放大倍数的乘积!")
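
# Stable Diffusion tab: builds the txt2img/img2img workflow node by node and
# renders the corresponding Gradio controls.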
class SD:
    def generate(initialized, batch_count, ckpt_name, vae_name, clip_mode, clip_skip, width, height, batch_size, negative_prompt, positive_prompt, seed, steps, cfg, sampler_name, scheduler, denoise, input_image, progress=gr.Progress()):
        module = "SD"
        ckpt_name = Function.get_model_path(ckpt_name)
        seed = Function.gen_seed(seed)
        if input_image is not None:
            input_image = Function.upload_image(input_image)
        counter = 1
        output_images = []
        node_id = 0
        while counter <= batch_count:
            workflow = {}
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"ckpt_name": ckpt_name}, "class_type": "CheckpointLoaderSimple"}
            model_port = [str(node_id), 0]
            clip_port = [str(node_id), 1]
            vae_port = [str(node_id), 2]
            if vae_name != "Automatic":
                node_id += 1
                workflow[str(node_id)] = {"inputs": {"vae_name": vae_name}, "class_type": "VAELoader"}
                vae_port = [str(node_id), 0]
            if input_image is None:
                node_id += 1
                workflow[str(node_id)] = {"inputs": {"width": width, "height": height, "batch_size": batch_size}, "class_type": "EmptyLatentImage"}
                latent_image_port = [str(node_id), 0]
            else:
                node_id += 1
                workflow[str(node_id)] = {"inputs": {"image": input_image, "upload": "image"}, "class_type": "LoadImage"}
                pixels_port = [str(node_id), 0]
                node_id += 1
                workflow[str(node_id)] = {"inputs": {"pixels": pixels_port, "vae": vae_port}, "class_type": "VAEEncode"}
                latent_image_port = [str(node_id), 0]
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"stop_at_clip_layer": -clip_skip, "clip": clip_port}, "class_type": "CLIPSetLastLayer"}
            clip_port = [str(node_id), 0]
            if initialized is True and module in Lora.cache:
                workflow, node_id, model_port, clip_port = Lora.add_node(module, workflow, node_id, model_port, clip_port)
            node_id += 1
            if clip_mode == "ComfyUI":
                workflow[str(node_id)] = {"inputs": {"text": positive_prompt, "clip": clip_port}, "class_type": "CLIPTextEncode"}
            else:
                workflow[str(node_id)] = {"inputs": {"text": positive_prompt, "token_normalization": "none", "weight_interpretation": "A1111", "clip": clip_port}, "class_type": "BNK_CLIPTextEncodeAdvanced"}
            positive_port = [str(node_id), 0]
            node_id += 1
            if clip_mode == "ComfyUI":
                workflow[str(node_id)] = {"inputs": {"text": negative_prompt, "clip": clip_port}, "class_type": "CLIPTextEncode"}
            else:
                workflow[str(node_id)] = {"inputs": {"text": negative_prompt, "token_normalization": "none", "weight_interpretation": "A1111", "clip": clip_port}, "class_type": "BNK_CLIPTextEncodeAdvanced"}
            negative_port = [str(node_id), 0]
            if initialized is True and module in ControlNet.cache:
                workflow, node_id, positive_port, negative_port = ControlNet.add_node(module, counter, workflow, node_id, positive_port, negative_port)
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"seed": seed, "steps": steps, "cfg": cfg, "sampler_name": sampler_name, "scheduler": scheduler, "denoise": denoise, "model": model_port, "positive": positive_port, "negative": negative_port, "latent_image": latent_image_port}, "class_type": "KSampler"}
            samples_port = [str(node_id), 0]
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"samples": samples_port, "vae": vae_port}, "class_type": "VAEDecode"}
            image_port = [str(node_id), 0]
            if initialized is True:
                workflow, node_id, image_port = Postprocess.add_node(module, workflow, node_id, image_port, model_port, clip_port, vae_port, positive_port, negative_port, seed, steps, cfg, sampler_name, scheduler)
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"filename_prefix": "ComfyUI", "images": image_port}, "class_type": "SaveImage"}
            images = Function.gen_image(workflow, counter, batch_count, progress)[0]
            if images is None:
                break
            for image in images:
                output_images.append(image)
            seed += 1
            counter += 1
        return output_images, output_images

    def blocks():
        with gr.Row():
            with gr.Column():
                positive_prompt = gr.Textbox(placeholder="Positive prompt | 正向提示词", show_label=False, value=Default.prompt, lines=3)
                negative_prompt = gr.Textbox(placeholder="Negative prompt | 负向提示词", show_label=False, value=Default.negative_prompt, lines=3)
                with gr.Tab(label="基础设置"):
                    with gr.Row():
                        ckpt_name = gr.Dropdown(Choices.ckpt, label="Ckpt name | Ckpt 模型名称", value=Choices.ckpt[0])
                        vae_name = gr.Dropdown(Choices.vae, label="VAE name | VAE 模型名称", value=Choices.vae[0])
                        if "BNK_CLIPTextEncodeAdvanced" in Choices.object_info:
                            clip_mode = gr.Dropdown(["ComfyUI", "WebUI"], label="Clip 编码类型", value="ComfyUI")
                        else:
                            clip_mode = gr.Dropdown(["ComfyUI", "WebUI"], label="Clip 编码类型", value="ComfyUI", visible=False)
                        clip_skip = gr.Slider(minimum=1, maximum=12, step=1, label="Clip 跳过", value=1)
                    with gr.Row():
                        width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width | 图像宽度", value=Default.width)
                        batch_size = gr.Slider(minimum=1, maximum=8, step=1, label="Batch size | 批次大小", value=1)
                    with gr.Row():
                        height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height | 图像高度", value=Default.height)
                        batch_count = gr.Slider(minimum=1, maximum=100, step=1, label="Batch count | 生成批次", value=1)
                    with gr.Row():
                        if Choices.lora != []:
                            Lora.blocks("SD")
                        if Choices.embedding != []:
                            embedding = gr.Dropdown(Choices.embedding, label="Embedding", multiselect=True, interactive=True)
                            embedding.change(fn=Function.add_embedding, inputs=[embedding, negative_prompt], outputs=[negative_prompt])
                    with gr.Row():
                        SD.input_image = gr.Image(value=None, type="pil")
                        gr.HTML("<br>上传图片即自动转为图生图模式。<br><br>文生图、图生图模式共享设置参数。<br><br>图像宽度、图像高度、批次大小对图生图无效。")
                with gr.Tab(label="采样设置"):
                    with gr.Row():
                        sampler_name = gr.Dropdown(Choices.sampler, label="Sampling method | 采样方法", value=Choices.sampler[12])
                        scheduler = gr.Dropdown(Choices.scheduler, label="Schedule type | 采样计划表类型", value=Choices.scheduler[1])
                    with gr.Row():
                        denoise = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
                        steps = gr.Slider(minimum=1, maximum=100, step=1, label="Sampling steps | 采样次数", value=Default.steps)
                    with gr.Row():
                        cfg = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=7)
                        seed = gr.Slider(minimum=-1, maximum=18446744073709550000, step=1, label="Seed | 种子数", value=-1)
                if Choices.controlnet_model != []:
                    ControlNet.blocks("SD")
                Postprocess.blocks("SD")
            with gr.Column():
                with gr.Row():
                    btn = gr.Button("Generate | 生成", elem_id="button")
                    btn2 = gr.Button("Interrupt | 终止")
                output = gr.Gallery(preview=True, height=600)
                with gr.Row():
                    SD.send_to_sd = gr.Button("发送图片至 SD")
                    if SC.enable is True:
                        SD.send_to_sc = gr.Button("发送图片至 SC")
                    if SVD.enable is True:
                        SD.send_to_svd = gr.Button("发送图片至 SVD")
                    SD.send_to_extras = gr.Button("发送图片至 Extras")
                    SD.send_to_info = gr.Button("发送图片至 Info")
        SD.data = gr.State()
        SD.index = gr.State()
        btn.click(fn=SD.generate, inputs=[Initial.initialized, batch_count, ckpt_name, vae_name, clip_mode, clip_skip, width, height, batch_size, negative_prompt, positive_prompt, seed, steps, cfg, sampler_name, scheduler, denoise, SD.input_image], outputs=[output, SD.data])
        btn2.click(fn=Function.post_interrupt, inputs=None, outputs=None)
        output.select(fn=Function.get_gallery_index, inputs=None, outputs=[SD.index])
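
# Stable Cascade tab: a fixed two-stage (Stage C -> Stage B) workflow; the
# tab only appears when both stage checkpoints are installed.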
class SC:
    if Default.design_mode == 1:
        enable = True
    elif "stable_cascade_stage_c.safetensors" in Choices.ckpt_list and "stable_cascade_stage_b.safetensors" in Choices.ckpt_list:
        enable = True
    else:
        enable = False

    def generate(initialized, batch_count, positive_prompt, negative_prompt, width, height, batch_size, seed_c, steps_c, cfg_c, sampler_name_c, scheduler_c, denoise_c, seed_b, steps_b, cfg_b, sampler_name_b, scheduler_b, denoise_b, input_image, progress=gr.Progress()):
        module = "SC"
        ckpt_name_c = Function.get_model_path("stable_cascade_stage_c.safetensors")
        ckpt_name_b = Function.get_model_path("stable_cascade_stage_b.safetensors")
        seed_c = Function.gen_seed(seed_c)
        seed_b = Function.gen_seed(seed_b)
        if input_image is not None:
            input_image = Function.upload_image(input_image)
        counter = 1
        output_images = []
        while counter <= batch_count:
            workflow = {
                "1": {"inputs": {"ckpt_name": ckpt_name_c}, "class_type": "CheckpointLoaderSimple"},
                "2": {"inputs": {"image": input_image, "upload": "image"}, "class_type": "LoadImage"},
                "3": {"inputs": {"compression": 42, "image": ["2", 0], "vae": ["1", 2]}, "class_type": "StableCascade_StageC_VAEEncode"},
                "4": {"inputs": {"text": negative_prompt, "clip": ["1", 1]}, "class_type": "CLIPTextEncode"},
                "5": {"inputs": {"text": positive_prompt, "clip": ["1", 1]}, "class_type": "CLIPTextEncode"},
                "6": {"inputs": {"seed": seed_c, "steps": steps_c, "cfg": cfg_c, "sampler_name": sampler_name_c, "scheduler": scheduler_c, "denoise": denoise_c, "model": ["1", 0], "positive": ["5", 0], "negative": ["4", 0], "latent_image": ["3", 0]}, "class_type": "KSampler"},
                "7": {"inputs": {"conditioning": ["5", 0], "stage_c": ["6", 0]}, "class_type": "StableCascade_StageB_Conditioning"},
                "8": {"inputs": {"ckpt_name": ckpt_name_b}, "class_type": "CheckpointLoaderSimple"},
                "9": {"inputs": {"seed": seed_b, "steps": steps_b, "cfg": cfg_b, "sampler_name": sampler_name_b, "scheduler": scheduler_b, "denoise": denoise_b, "model": ["8", 0], "positive": ["7", 0], "negative": ["4", 0], "latent_image": ["3", 1]}, "class_type": "KSampler"},
                "10": {"inputs": {"samples": ["9", 0], "vae": ["8", 2]}, "class_type": "VAEDecode"}
            }
            if input_image is None:
                del workflow["2"]
                workflow["3"] = {"inputs": {"width": width, "height": height, "compression": 42, "batch_size": batch_size}, "class_type": "StableCascade_EmptyLatentImage"}
            node_id = 10
            image_port = [str(node_id), 0]
            if initialized is True:
                workflow, node_id, image_port = Postprocess.add_node(module, workflow, node_id, image_port)
            node_id += 1
            workflow[str(node_id)] = {"inputs": {"filename_prefix": "ComfyUI", "images": image_port}, "class_type": "SaveImage"}
            images = Function.gen_image(workflow, counter, batch_count, progress)[0]
            if images is None:
                break
            for image in images:
                output_images.append(image)
            seed_c += 1
            counter += 1
        return output_images, output_images

    def blocks():
        with gr.Row():
            with gr.Column():
                positive_prompt = gr.Textbox(placeholder="Positive prompt | 正向提示词", show_label=False, value=Default.prompt, lines=3)
                negative_prompt = gr.Textbox(placeholder="Negative prompt | 负向提示词", show_label=False, value=Default.negative_prompt, lines=3)
                with gr.Tab(label="基础设置"):
                    with gr.Row():
                        width = gr.Slider(minimum=128, maximum=2048, step=128, label="Width | 图像宽度", value=1024)
                        batch_size = gr.Slider(minimum=1, maximum=8, step=1, label="Batch size | 批次大小", value=1)
                    with gr.Row():
                        height = gr.Slider(minimum=128, maximum=2048, step=128, label="Height | 图像高度", value=1024)
                        batch_count = gr.Slider(minimum=1, maximum=100, step=1, label="Batch count | 生成批次", value=1)
                    with gr.Row():
                        SC.input_image = gr.Image(value=None, type="pil")
                        gr.HTML("<br>上传图片即自动转为图生图模式。<br><br>文生图、图生图模式共享设置参数。<br><br>图像宽度、图像高度、批次大小对图生图无效。")
                with gr.Tab(label="Stage C 采样设置"):
                    with gr.Row():
                        sampler_name_c = gr.Dropdown(Choices.sampler, label="Sampling method | 采样方法", value=Choices.sampler[12])
                        scheduler_c = gr.Dropdown(Choices.scheduler, label="Schedule type | 采样计划表类型", value=Choices.scheduler[1])
                    with gr.Row():
                        denoise_c = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
                        steps_c = gr.Slider(minimum=10, maximum=30, step=1, label="Sampling steps | 采样次数", value=20)
                    with gr.Row():
                        cfg_c = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=4)
                        seed_c = gr.Slider(minimum=-1, maximum=18446744073709550000, step=1, label="Seed | 种子数", value=-1)
                with gr.Tab(label="Stage B 采样设置"):
                    with gr.Row():
                        sampler_name_b = gr.Dropdown(Choices.sampler, label="Sampling method | 采样方法", value=Choices.sampler[12])
                        scheduler_b = gr.Dropdown(Choices.scheduler, label="Schedule type | 采样计划表类型", value=Choices.scheduler[1])
                    with gr.Row():
                        denoise_b = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
                        steps_b = gr.Slider(minimum=4, maximum=12, step=1, label="Sampling steps | 采样次数", value=10)
                    with gr.Row():
                        cfg_b = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=1.1)
                        seed_b = gr.Slider(minimum=-1, maximum=18446744073709550000, step=1, label="Seed | 种子数", value=-1)
                Postprocess.blocks("SC")
            with gr.Column():
                with gr.Row():
                    btn = gr.Button("Generate | 生成", elem_id="button")
                    btn2 = gr.Button("Interrupt | 终止")
                output = gr.Gallery(preview=True, height=600)
                with gr.Row():
                    SC.send_to_sd = gr.Button("发送图片至 SD")
                    SC.send_to_sc = gr.Button("发送图片至 SC")
                    if SVD.enable is True:
                        SC.send_to_svd = gr.Button("发送图片至 SVD")
                    SC.send_to_extras = gr.Button("发送图片至 Extras")
                    SC.send_to_info = gr.Button("发送图片至 Info")
        SC.data = gr.State()
        SC.index = gr.State()
        btn.click(fn=SC.generate, inputs=[Initial.initialized, batch_count, positive_prompt, negative_prompt, width, height, batch_size, seed_c, steps_c, cfg_c, sampler_name_c, scheduler_c, denoise_c, seed_b, steps_b, cfg_b, sampler_name_b, scheduler_b, denoise_b, SC.input_image], outputs=[output, SC.data])
        btn2.click(fn=Function.post_interrupt, inputs=None, outputs=None)
        output.select(fn=Function.get_gallery_index, inputs=None, outputs=[SC.index])
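
# Stable Video Diffusion tab: img2vid via SVD_img2vid_Conditioning and
# SaveAnimatedWEBP; enabled only when svd_xt_1_1.safetensors is installed.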
class SVD:
    if Default.design_mode == 1:
        enable = True
    elif "svd_xt_1_1.safetensors" in Choices.ckpt_list:
        enable = True
    else:
        enable = False

    def generate(input_image, width, height, video_frames, motion_bucket_id, fps, augmentation_level, min_cfg, seed, steps, cfg, sampler_name, scheduler, denoise, fps2, lossless, quality, method, progress=gr.Progress()):
        ckpt_name = Function.get_model_path("svd_xt_1_1.safetensors")
        seed = Function.gen_seed(seed)
        if input_image is None:
            return
        else:
            input_image = Function.upload_image(input_image)
        workflow = {
            "1": {"inputs": {"ckpt_name": ckpt_name}, "class_type": "ImageOnlyCheckpointLoader"},
            "2": {"inputs": {"image": input_image, "upload": "image"}, "class_type": "LoadImage"},
            "3": {"inputs": {"width": width, "height": height, "video_frames": video_frames, "motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "clip_vision": ["1", 1], "init_image": ["2", 0], "vae": ["1", 2]}, "class_type": "SVD_img2vid_Conditioning"},
            "4": {"inputs": {"min_cfg": min_cfg, "model": ["1", 0]}, "class_type": "VideoLinearCFGGuidance"},
            "5": {"inputs": {"seed": seed, "steps": steps, "cfg": cfg, "sampler_name": sampler_name, "scheduler": scheduler, "denoise": denoise, "model": ["4", 0], "positive": ["3", 0], "negative": ["3", 1], "latent_image": ["3", 2]}, "class_type": "KSampler"},
            "6": {"inputs": {"samples": ["5", 0], "vae": ["1", 2]}, "class_type": "VAEDecode"},
            "7": {"inputs": {"filename_prefix": "ComfyUI", "fps": fps2, "lossless": lossless == "true", "quality": quality, "method": method, "images": ["6", 0]}, "class_type": "SaveAnimatedWEBP"}
        }
        return Function.gen_image(workflow, 1, 1, progress)[1]

    def blocks():
        with gr.Row():
            with gr.Column():
                SVD.input_image = gr.Image(value=None, type="pil")
                with gr.Tab(label="基础设置"):
                    with gr.Row():
                        width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width | 图像宽度", value=512)
                        video_frames = gr.Slider(minimum=1, maximum=25, step=1, label="Video frames | 视频帧", value=25)
                    with gr.Row():
                        height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height | 图像高度", value=512)
                        fps = gr.Slider(minimum=1, maximum=30, step=1, label="FPS | 帧率", value=6)
                    with gr.Row():
                        with gr.Column():
                            augmentation_level = gr.Slider(minimum=0, maximum=1, step=0.01, label="Augmentation level | 增强级别", value=0)
                            motion_bucket_id = gr.Slider(minimum=1, maximum=256, step=1, label="Motion bucket id | 运动参数", value=127)
                        with gr.Column():
                            min_cfg = gr.Slider(minimum=0, maximum=20, step=0.5, label="Min CFG | 最小CFG权重", value=1)
                with gr.Tab(label="采样设置"):
                    with gr.Row():
                        sampler_name = gr.Dropdown(Choices.sampler, label="Sampling method | 采样方法", value=Choices.sampler[12])
                        scheduler = gr.Dropdown(Choices.scheduler, label="Schedule type | 采样计划表类型", value=Choices.scheduler[1])
                    with gr.Row():
                        denoise = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
                        steps = gr.Slider(minimum=10, maximum=30, step=1, label="Sampling steps | 采样次数", value=20)
                    with gr.Row():
                        cfg = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=2.5)
                        seed = gr.Slider(minimum=-1, maximum=18446744073709550000, step=1, label="Seed | 种子数", value=-1)
                with gr.Tab(label="输出设置"):
                    with gr.Row():
                        method = gr.Dropdown(["default", "fastest", "slowest"], label="Method | 输出方法", value="default")
                        lossless = gr.Dropdown(["true", "false"], label="Lossless | 无损压缩", value="false")
                    with gr.Row():
                        quality = gr.Slider(minimum=70, maximum=100, step=1, label="Quality | 输出质量", value=85)
                        fps2 = gr.Slider(minimum=1, maximum=30, step=1, label="FPS | 帧率", value=10)
            with gr.Column():
                with gr.Row():
                    btn = gr.Button("Generate | 生成", elem_id="button")
                    btn2 = gr.Button("Interrupt | 终止")
                output = gr.Image(height=600)
        btn.click(fn=SVD.generate, inputs=[SVD.input_image, width, height, video_frames, motion_bucket_id, fps, augmentation_level, min_cfg, seed, steps, cfg, sampler_name, scheduler, denoise, fps2, lossless, quality, method], outputs=[output])
        btn2.click(fn=Function.post_interrupt, inputs=None, outputs=None)
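
# Extras tab: standalone image upscaling (arithmetic and/or model-based).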
class Extras:
    def generate(initialized, input_image, progress=gr.Progress()):
        module = "Extras"
        if input_image is None:
            return
        else:
            input_image = Function.upload_image(input_image)
        workflow = {}
        node_id = 1
        workflow[str(node_id)] = {"inputs": {"image": input_image, "upload": "image"}, "class_type": "LoadImage"}
        image_port = [str(node_id), 0]
        if initialized is True:
            if module not in Upscale.cache and module not in UpscaleWithModel.cache:
                return
            if module in Upscale.cache:
                workflow, node_id, image_port = Upscale.add_node(module, workflow, node_id, image_port)
            if module in UpscaleWithModel.cache:
                workflow, node_id, image_port = UpscaleWithModel.add_node(module, workflow, node_id, image_port)
        else:
            return
        node_id += 1
        workflow[str(node_id)] = {"inputs": {"filename_prefix": "ComfyUI", "images": image_port}, "class_type": "SaveImage"}
        output = Function.gen_image(workflow, 1, 1, progress)[0]
        if output is not None:
            output = output[0]
        return output

    def blocks():
        with gr.Row():
            with gr.Column():
                Extras.input_image = gr.Image(value=None, type="pil")
                with gr.Row():
                    with gr.Tab(label="算术放大"):
                        Upscale.blocks("Extras")
                with gr.Row():
                    with gr.Tab(label="超分放大"):
                        UpscaleWithModel.blocks("Extras")
                gr.HTML("注意:同时启用两种放大模式将先执行算术放大,再执行超分放大,最终放大倍数为二者放大倍数的乘积!")
            with gr.Column():
                with gr.Row():
                    btn = gr.Button("Generate | 生成", elem_id="button")
                    btn2 = gr.Button("Interrupt | 终止")
                output = gr.Image(height=600)
        btn.click(fn=Extras.generate, inputs=[Initial.initialized, Extras.input_image], outputs=[output])
        btn2.click(fn=Function.post_interrupt, inputs=None, outputs=None)
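
# Info tab: reads workflow metadata from a PNG or a workflow_api.json file,
# pretty-prints the re-ordered API workflow, and can re-run it.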
class Info:
    def generate(image_info, progress=gr.Progress()):
        if not image_info or image_info is None or image_info == "仅支持API工作流!!!" or "Version:" in image_info or image_info == "None":
            return
        workflow = json.loads(image_info)
        return Function.gen_image(workflow, 1, 1, progress)[0]

    def order_workflow(workflow):
        if workflow is None:
            return gr.update(visible=False, value=None)
        workflow = json.loads(workflow)
        if "last_node_id" in workflow:
            return gr.update(show_label=False, visible=True, value="仅支持API工作流!!!", lines=1)
        workflow = Function.order_workflow(workflow)
        lines = len(workflow) + 5
        workflow_string = "{"
        for node in workflow:
            workflow_string = workflow_string + "\n" + f'"{node}": {workflow[node]},'
        workflow_string = workflow_string + "\n}"
        workflow_string = workflow_string.replace(",\n}", "\n}")
        workflow_string = workflow_string.replace("'", '"')
        return gr.update(label="Ordered workflow_api", show_label=True, visible=True, value=workflow_string, lines=lines)

    def get_image_info(image_pil):
        if image_pil is None:
            return gr.update(visible=False, value=None)
        else:
            image_info = Function.get_image_info(image_pil)
            if image_info == "None":
                return gr.update(visible=False, value=None)
            if "Version:" in image_info:
                return gr.update(label="Image info", show_label=True, visible=True, value=image_info, lines=3)
            return Info.order_workflow(image_info)

    def hide_another_input(this_input):
        if this_input is None:
            return gr.update(visible=True)
        return gr.update(visible=False)

    def blocks():
        with gr.Row():
            with gr.Column():
                Info.input_image = gr.Image(value=None, type="pil")
                workflow = gr.File(label="workflow_api.json", file_types=[".json"], type="binary")
                image_info = gr.Textbox(visible=False)
            with gr.Column():
                with gr.Row():
                    btn = gr.Button("Generate | 生成", elem_id="button")
                    btn2 = gr.Button("Interrupt | 终止")
                output = gr.Gallery(preview=True, height=600)
        btn.click(fn=Info.generate, inputs=[image_info], outputs=[output])
        btn2.click(fn=Function.post_interrupt, inputs=None, outputs=None)
        Info.input_image.change(fn=Info.hide_another_input, inputs=[Info.input_image], outputs=[workflow])
        Info.input_image.change(fn=Info.get_image_info, inputs=[Info.input_image], outputs=[image_info])
        workflow.change(fn=Info.hide_another_input, inputs=[workflow], outputs=[Info.input_image])
        workflow.change(fn=Info.order_workflow, inputs=[workflow], outputs=[image_info])
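
# Top-level layout: one tab per module, plus the "send to ..." wiring that
# moves a selected gallery image between tabs.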
with gr.Blocks(css="#button {background: #FFE1C0; color: #FF453A} .block.padded:not(.gradio-accordion) {padding: 0 !important;} div.form {border-width: 0; box-shadow: none; background: white; gap: 1.15em;}") as demo:
    Initial.initialized = gr.Checkbox(value=False, visible=False)
    with gr.Tab(label="Stable Diffusion"):
        SD.blocks()
    if SC.enable is True:
        with gr.Tab(label="Stable Cascade"):
            SC.blocks()
    if SVD.enable is True:
        with gr.Tab(label="Stable Video Diffusion"):
            SVD.blocks()
    with gr.Tab(label="Extras"):
        Extras.blocks()
    with gr.Tab(label="Info"):
        Info.blocks()
    SD.send_to_sd.click(fn=Function.send_to, inputs=[SD.data, SD.index], outputs=[SD.input_image])
    if SC.enable is True:
        SD.send_to_sc.click(fn=Function.send_to, inputs=[SD.data, SD.index], outputs=[SC.input_image])
    if SVD.enable is True:
        SD.send_to_svd.click(fn=Function.send_to, inputs=[SD.data, SD.index], outputs=[SVD.input_image])
    SD.send_to_extras.click(fn=Function.send_to, inputs=[SD.data, SD.index], outputs=[Extras.input_image])
    SD.send_to_info.click(fn=Function.send_to, inputs=[SD.data, SD.index], outputs=[Info.input_image])
    if SC.enable is True:
        SC.send_to_sd.click(fn=Function.send_to, inputs=[SC.data, SC.index], outputs=[SD.input_image])
        SC.send_to_sc.click(fn=Function.send_to, inputs=[SC.data, SC.index], outputs=[SC.input_image])
        if SVD.enable is True:
            SC.send_to_svd.click(fn=Function.send_to, inputs=[SC.data, SC.index], outputs=[SVD.input_image])
        SC.send_to_extras.click(fn=Function.send_to, inputs=[SC.data, SC.index], outputs=[Extras.input_image])
        SC.send_to_info.click(fn=Function.send_to, inputs=[SC.data, SC.index], outputs=[Info.input_image])

demo.queue(concurrency_count=100).launch(inbrowser=True, inline=False)
