当前位置:   article > 正文

Stable Diffusion | Gradio界面设计及API调用_gradio api

gradio api

Stability AI 2024年2月发布了 Stable Cascade 模型,但由于该模型较大(fp32格式的 Stage_A + Stage_B + Stage_C 模型超过20GB,ComfyUI 专用 Stage_B + Stage_C 模型也要14GB),对显卡要求较高,限制了大家体验 Stable Cascade 模型。

本文主要介绍如何在百度 AI Studio 平台上通过 Gradio 交互式界面运行 Stable Cascade 模型进行文生图。

项目链接:Stable Diffusion ComfyUI 在线体验 。

1. 项目体验

由于项目更新,以下示例图片与新版本有一定差异,但操作方法一样

1.1 创建项目副本

打开项目链接 Stable Diffusion | Gradio界面设计及API调用

百度 AI Studio 平台需要登陆百度账号使用

点击右上角红色方框内的 fork 按钮

创建项目副本

运行项目

1.2 获取免费算力卡

百度 AI Studio 平台每日免费算力卡需要运行任意项目后发放,fork 项目后点击启动环境即可获取8点免费算力卡

1.3 启动GPU环境

选择 V100 32GB 运行环境并点击确定

进入环境

1.4 部署ComfyUI

双击打开左侧文件浏览器中的 1 部署ComfyUI.ipynb

点击红框内按钮解压 ComfyUI 部署包

解压完成(解压仅需1-2分钟):

1.5 启动ComfyUI-API

双击打开左侧文件浏览器中的 2 启动ComfyUI-API.ipynb 并点击红框内按钮启动 ComfyUI-API

ComfyUI-API 已启动:

1.6 启动Gradio界面

双击打开左侧文件浏览器中的 3 启动Gradio界面.gradio.py,等待 Gradio 界面加载完成后点击红框内按钮在新的浏览器页面打开 Gradio 界面

浏览器新页面中的 Gradio 界面:

点击右上角红框内按钮开始文生图,首次运行因为要加载约14GB的 Stable Cascade 模型到显存,第一张图片大约需要2分钟才能生成,后续生成一张 1024*1024 图片大约需要30秒。

文生图示例:

1.7 停止项目

由于 AI Studio 平台每天的免费算力卡只有8点,运行 V100 32GB 环境每小时消耗3点算力卡,不生图时应尽快关闭项目

依次关闭四个选项卡

无需保存修改

项目首页点击停止按钮

2. Gradio界面设计及API调用源码

源码随项目不断更新,最新版本见项目内部(Stable Diffusion | Gradio界面设计及API调用)。

  1. import gradio as gr
  2. import io
  3. import json
  4. import os
  5. import random
  6. import requests
  7. import urllib.parse
  8. import uuid
  9. from PIL import Image
  10. import sys
  11. sys.path.append("/home/aistudio/work/ComfyUI/venv/lib/python3.10/site-packages")
  12. import websocket
  13. # 定义ComfyUI服务器地址
  14. server_address = "127.0.0.1:8188"
  15. # 定义SD模型所在文件夹路径,默认sd_models_path为该py文件所在路径+"/data"
  16. sd_models_path = os.getcwd() + "/data"
  17. # 定义默认正向提示词
  18. default_prompt = "evening sunset scenery blue sky nature, glass bottle with a fizzy ice cold freezing rainbow liquid in it"
  19. # 定义默认负向提示词
  20. default_negative_prompt = "text, watermark"
  21. # 定义可选择采样器和采样计划表类型(数组格式)
  22. samplers = ["euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
  23. schedulers = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
  24. # 定义 获取指定SD模型(sd_model)在sd_models_path中的路径 的函数
  25. def get_model_path(sd_model):
  26. # 当指定SD模型在sd_models_path根目录中时,模型路径与模型名称相同
  27. if os.path.exists(os.path.join(sd_models_path, sd_model)):
  28. sd_model_path = sd_model
  29. # 当指定SD模型在sd_models_path子目录中时,模型路径为"子目录名称/模型名称"
  30. else:
  31. for folder in os.listdir(sd_models_path):
  32. temp_sd_models_path = os.path.join(sd_models_path, folder)
  33. if os.path.exists(os.path.join(temp_sd_models_path, sd_model)):
  34. sd_model_path = os.path.join(folder, sd_model)
  35. return sd_model_path
  36. # 定义客户端ID,用于和服务器建立websocket连接
  37. client_id = str(uuid.uuid4())
  38. # 定义 向服务器提交工作流并获取生成的图片 的函数
  39. def generate_images(workflow):
  40. # 与服务器建立websocket连接
  41. ws = websocket.WebSocket()
  42. ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
  43. data = {"prompt":workflow, "client_id":client_id}
  44. prompt_id = requests.post(url="http://{}/prompt".format(server_address), json=data).json()["prompt_id"]
  45. while True:
  46. wsrecv = ws.recv()
  47. if isinstance(wsrecv, str):
  48. message = json.loads(wsrecv)
  49. if message["type"] == "executing":
  50. data = message["data"]
  51. if data["node"] is None and data["prompt_id"] == prompt_id:
  52. break
  53. else:
  54. continue
  55. history = requests.get(url="http://{}/history/{}".format(server_address, prompt_id)).json()[prompt_id]
  56. for output in history["outputs"]:
  57. for node_id in history["outputs"]:
  58. node_output = history["outputs"][node_id]
  59. if "images" in node_output:
  60. images = []
  61. for image in node_output["images"]:
  62. data = {"filename":image["filename"], "subfolder":image["subfolder"], "type":image["type"]}
  63. url_values = urllib.parse.urlencode(data)
  64. image_data = requests.get("http://{}/view?{}".format(server_address, url_values)).content
  65. image = Image.open(io.BytesIO(image_data))
  66. images.append(image)
  67. return images
  68. # 定义 Stable Cascade文生图 函数
  69. def cascade_txt2img(positive_prompt, negative_prompt, width, height, compression, batch_size, seed_c, steps_c, cfg_c, sampler_name_c, scheduler_c, denoise_c, seed_b, steps_b, cfg_b, sampler_name_b, scheduler_b, denoise_b):
  70. if seed_c == "-1":
  71. seed_c = random.randint(0, 9223372036854775807)
  72. if seed_b == "-1":
  73. seed_b = random.randint(0, 9223372036854775807)
  74. # 定义Stable Cascade txt2img工作流
  75. cascade_txt2img_workflow = {
  76. "1":{"inputs":{"ckpt_name":get_model_path("stable_cascade_stage_c.safetensors")}, "class_type":"CheckpointLoaderSimple"},
  77. "2":{"inputs":{"text":positive_prompt, "clip":["1", 1]}, "class_type":"CLIPTextEncode"},
  78. "3":{"inputs":{"text":negative_prompt, "clip":["1", 1]}, "class_type":"CLIPTextEncode"},
  79. "4":{"inputs":{"width":width, "height":height, "compression":compression, "batch_size":batch_size}, "class_type":"StableCascade_EmptyLatentImage"},
  80. "5":{"inputs":{"seed":seed_c, "steps":steps_c, "cfg":cfg_c, "sampler_name":sampler_name_c, "scheduler":scheduler_c, "denoise":denoise_c, "model":["1", 0], "positive":["2", 0], "negative":["3", 0], "latent_image":["4", 0]}, "class_type":"KSampler"},
  81. "6":{"inputs":{"conditioning":["2", 0], "stage_c":["5", 0]}, "class_type":"StableCascade_StageB_Conditioning"},
  82. "7":{"inputs":{"ckpt_name":get_model_path("stable_cascade_stage_b.safetensors")}, "class_type":"CheckpointLoaderSimple"},
  83. "8":{"inputs":{"seed":seed_b, "steps":steps_b, "cfg":cfg_b, "sampler_name":sampler_name_b, "scheduler":scheduler_b, "denoise":denoise_b, "model":["7", 0], "positive":["6", 0], "negative":["3", 0], "latent_image":["4", 1]}, "class_type":"KSampler"},
  84. "9":{"inputs":{"samples":["8", 0], "vae":["7", 2]}, "class_type":"VAEDecode"},
  85. "10":{"inputs":{"filename_prefix":"Cascade", "images":["9", 0]}, "class_type":"SaveImage"}
  86. }
  87. images = generate_images(cascade_txt2img_workflow)
  88. return images
  89. # Gradio界面设计
  90. with gr.Blocks() as demo:
  91. # 以下模块按行排列
  92. with gr.Row():
  93. # 以下模块按列排列
  94. with gr.Column():
  95. # gr.Textbox()为可输入文本框,label为该模块的标签,value为默认值
  96. positive_prompt = gr.Textbox(label="Positive prompt | 正向提示词", value=default_prompt)
  97. negative_prompt = gr.Textbox(label="Negative prompt | 负向提示词", value=default_negative_prompt)
  98. # gr.Tab()为选项卡模块,label为该模块的标签
  99. with gr.Tab(label="Stage C 采样阶段设置"):
  100. # 以下模块合并为组
  101. with gr.Group():
  102. with gr.Row():
  103. # gr.Dropdown()为可下拉选择框,第一个参数必须为包含下拉选项的数组["..."],label为该模块的标签,value为默认值
  104. sampler_name_c = gr.Dropdown(samplers, label="Sampling method | 采样方法", value=samplers[12])
  105. scheduler_c = gr.Dropdown(schedulers, label="Schedule type | 采样计划表类型", value=schedulers[1])
  106. with gr.Row():
  107. # gr.Slider()为滑块模块,minimum为最小数值,maximum为最大数值,step为最小滑动步长,label为该模块的标签,value为默认值
  108. width = gr.Slider(minimum=512, maximum=2048, step=128, label="Width | 图像宽度", value=1024)
  109. steps_c = gr.Slider(minimum=10, maximum=30, step=1, label="Sampling steps | 采样次数", value=20)
  110. with gr.Row():
  111. height = gr.Slider(minimum=512, maximum=2048, step=128, label="Height | 图像高度", value=1024)
  112. batch_size = gr.Slider(minimum=1, maximum=8, step=1, label="Batch size | 单批次生成图像数", value=1)
  113. with gr.Row():
  114. denoise_c = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
  115. compression = gr.Slider(minimum=8, maximum=42, step=1, label="Compression | 压缩倍率", value=42)
  116. with gr.Row():
  117. cfg_c = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=4.0)
  118. seed_c = gr.Textbox(label="Seed | 种子数(-1表示随机种子数)", value=-1)
  119. with gr.Tab(label="Stage B 采样阶段设置", open=False):
  120. with gr.Row():
  121. sampler_name_b = gr.Dropdown(samplers, label="Sampling method | 采样方法", value=samplers[12])
  122. scheduler_b = gr.Dropdown(schedulers, label="Schedule type | 采样计划表类型", value=schedulers[1])
  123. with gr.Row():
  124. denoise_b = gr.Slider(minimum=0, maximum=1, step=0.05, label="Denoise | 去噪强度", value=1)
  125. steps_b = gr.Slider(minimum=4, maximum=12, step=1, label="Sampling steps | 采样次数", value=10)
  126. with gr.Row():
  127. cfg_b = gr.Slider(minimum=0, maximum=20, step=0.1, label="CFG Scale | CFG权重", value=1.1)
  128. seed_b = gr.Textbox(label="Seed | 种子数(-1表示随机种子数)", value=-1)
  129. with gr.Column():
  130. # gr.Button()为按键模块,仅显示一个按键,需搭配.Click()定义该按键功能
  131. btn = gr.Button("Generate | 生成")
  132. # gr.Gallery()为画廊模块,可以显示一张或多张生成图片,设置preview=True可开启预览模式,height参数为画廊模块的高度,单位为像素
  133. gallery = gr.Gallery(preview=True, height=640)
  134. # .Click()模块,可用于定义按键功能,fn为按下该按键后调用的函数,inputs为该函数的输入值,outputs为Gradio输出内容,outputs内模块会依次读取函数返回的值,注意顺序和数值类型!!!
  135. btn.click(fn=cascade_txt2img, inputs=[positive_prompt, negative_prompt, width, height, compression, batch_size, seed_c, steps_c, cfg_c, sampler_name_c, scheduler_c, denoise_c, seed_b, steps_b, cfg_b, sampler_name_b, scheduler_b, denoise_b], outputs=[gallery])
  136. # .queue()可指定队列相关参数,此处status_update_rate=30为每30秒给客户端发送队列完成状态,用于防止Gradio超时60秒后自动报错并退出,此处inbrowser=True可在Gradio启动后自动打开网页,受AI Studio平台限制,该参数无法打开网页~
  137. demo.queue(status_update_rate=30).launch(inbrowser=True)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/881472
推荐阅读
相关标签
  

闽ICP备14008679号