当前位置:   article > 正文

AIGC专栏13——ComfyUI 插件编写细节解析-以EasyAnimateV3为例_comfyui easyanimate

comfyui easyanimate

学习前言

最近在给EasyAnimateV3写ComfyUI的工作流,以方便别人测试。学习了ComfyUI的基本操作,也看了一下别人是怎么写的,自己也折腾了一下。
在这里插入图片描述

什么是ComfyUI

人工智能艺术创作的领域里,Stable Diffusion 凭借其开放源代码的特性,吸引了众多开发者与艺术家的目光,并且因为强大的社区支持而展现出强大的影响力。

Stable Diffusion 的两大主流创作平台分别是 Stable Diffusion WebUI 与 ComfyUI。前者以其即装即用的便捷性、丰富的基础功能及广泛的社区插件支持,成为了新手的理想选择。而ComfyUI则更注重算法细节,主要特点是 结合了 工作流和节点,达成更高级别的自动化,使得创作流程更易于复现与传播。不过,这也意味着较高的学习曲线,要求用户对 Stable Diffusion 及其扩展功能有深入理解,动手实践能力亦需达到一定水准。

着眼于未来的工业化生产与效率提升,ComfyUI 显现出更为广阔的潜力与前景。
在这里插入图片描述

相关地址汇总

ComfyUI

https://github.com/comfyanonymous/ComfyUI

EasyAnimateV3

https://github.com/aigc-apps/EasyAnimate
感谢大家的关注。

节点例子

插件其实本质上是一个个节点,在ComfyUI中,一个节点对应一个类,本文先以ExampleNode为例进行解析,如下所示:

ExampleNode有几个方法,必要是的__init__INPUT_TYPES和一个FUNCTION 对应的函数。

  • __init__是python类的通用方法,非必须有内容,所以我们可以先pass。
  • INPUT_TYPES是ComfyUI必须要有的一个方法,需要放在classmethod装饰器下面,在INPUT_TYPES我们可以指定requiredoptional代表参数是必要的或者可选的。 required和optional均为一个字典,字典的key是这个参数的名称,key对应的value是这个参数的类别,这个类别既可以自定义,也可以用一些通用的,常见通用类别有IMAGEINTFLOAT等,具体的设置方法如下的代码所示。
  • FUNCTION对应的函数代表该节点执行的函数,这个函数一般是自定义的,如下所示的example_func,做的工作就是将所有的像素点除以2后+1。

除此之外还需要设置几个变量名:

  • FUNCTION = "example_func"对应了FUNCTION的函数名
  • RETURN_TYPES = ("IMAGE",)代表这个节点返回的内容类别,常见通用类别有IMAGEINTFLOAT
  • RETURN_NAMES = ("image_output_demo",)代表在UI上显示的节点返回的名称
  • CATEGORY = "Example"代表这个节点的种类。
class ExampleNode:def __init__(self):
        pass
    
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image_demo": ("IMAGE",),
                "int_demo": ("INT", {
                    "default": 0, 	# 默认值
                    "min": 0, 		# 最小值
                    "max": 4096, 	# 最大值
                    "step": 64, 	# 在INT类型中,每次改变的最小步长
                    "display": "number" # 该参数在ComfyUI上的外观: "number","slider"
                }),
                "float_demo": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 10.0,
                    "step": 0.01,
                    "round": 0.001, 	# 表示精度,不设置的话为步长值。可以设置为 False 以禁用。
                    "display": "number"}),
                "string_demo": ("STRING", {
                    "multiline": False, # 只显示一行字符串,多余部分会在 UI上隐藏
                    "default": "string_demo"
                }),
            },
            "optional":{
                "optional_image": ("IMAGE",),
            },
        }
    FUNCTION = "example_func"
    
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image_output_demo",)
    CATEGORY = "Example"

    def example_func(self, image_demo, int_demo, float_demo, string_demo, optional_image):
        image = image / 2.0 + 1
        return (image,)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

复杂例子-以EasyAnimateV3为例

由于正常的Stable Diffusion的库并不会只有一个简单的Example Node类,可能存在复杂的导入情况,在这个基础上我们进行解析。

节点文件

必要库的导入

文件结构如下,为了代码文件的干净,我们将nodes对应的python文件和json工作流放入到自己创建的名为comfyui的文件夹中。
在这里插入图片描述

除去一些pip安装的库,在导入其它EasyAnimate组件时,需要使用相对路径。

# 正常的pip库导入部分
import gc
import os

import torch
import numpy as np
from PIL import Image
from diffusers import (AutoencoderKL, DDIMScheduler,
                       DPMSolverMultistepScheduler,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       PNDMScheduler)
from einops import rearrange
from omegaconf import OmegaConf
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

import comfy.model_management as mm
import folder_paths
from comfy.utils import ProgressBar, load_torch_file

# 相对目录导入easyanimate组件
from ..easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
from ..easyanimate.models.transformer3d import Transformer3DModel
from ..easyanimate.pipeline.pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline
from ..easyanimate.utils.utils import get_image_to_video_latent
from ..easyanimate.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

载入模型节点定义

首先定义了INPUT_TYPES,载入模型的时候我们需要指定的必要参数是:

  • model 模型的名称,分别有'EasyAnimateV3-XL-2-InP-512x512'、'EasyAnimateV3-XL-2-InP-768x768'、'EasyAnimateV3-XL-2-InP-960x960'
  • low_gpu_memory_mode 是否开启 低显存模式。
  • config 用于指定EasyAnimateV3的config。
  • precision 模型的精度,EasyAnimateV3应该用bf16。

除此之外还需要设置几个变量名:

  • RETURN_TYPES = ("EASYANIMATESMODEL",) 代表这个节点返回的内容类别,我们自定义了一个EASYANIMATESMODEL类别
  • RETURN_NAMES = ("easyanimate_model",)代表在UI上显示的节点返回的名称
  • FUNCTION = "loadmodel"对应了FUNCTION的函数名
  • CATEGORY = "EasyAnimateWrapper"代表这个节点的种类。

在loadmodel函数中,我们载入了EasyAnimateV3的pipeline,并且通过字典的方式返回回去。

class LoadEasyAnimateModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
             	# 模型名称
                "model": (
                    [ 
                        'EasyAnimateV3-XL-2-InP-512x512',
                        'EasyAnimateV3-XL-2-InP-768x768',
                        'EasyAnimateV3-XL-2-InP-960x960'
                    ],
                    {
                        "default": 'EasyAnimateV3-XL-2-InP-768x768',
                    }
                ),
                # 是否开启 低显存模式
                "low_gpu_memory_mode":(
                    [False, True],
                    {
                        "default": False,
                    }
                ),
                # config 名称
                "config": (
                    [
                        "easyanimate_video_slicevae_motion_module_v3.yaml",
                    ],
                    {
                        "default": "easyanimate_video_slicevae_motion_module_v3.yaml",
                    }
                ),
                # 模型精度
                "precision": (
                    ['fp16', 'bf16'],
                    {
                        "default": 'bf16'
                    }
                ),
                
            },
        }
    # 定义了一个新的类别
    RETURN_TYPES = ("EASYANIMATESMODEL",)
    RETURN_NAMES = ("easyanimate_model",)
    # 对应了`FUNCTION`的函数名
    FUNCTION = "loadmodel"
    CATEGORY = "EasyAnimateWrapper"
	
	# 载入模型的函数
    def loadmodel(self, low_gpu_memory_mode, model, precision, config):
        # Init weight_dtype and device
        device          = mm.get_torch_device()
        offload_device  = mm.unet_offload_device()
        weight_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

        # Init processbar
        pbar = ProgressBar(4)

        # Load config
        config_path = f"{script_directory}/config/{config}"
        config = OmegaConf.load(config_path)

        # Detect model is existing or not 
        model_path = os.path.join(folder_paths.models_dir, "EasyAnimate", model)
      
        if not os.path.exists(model_path):
            if os.path.exists(eas_cache_dir):
                model_path = os.path.join(eas_cache_dir, 'EasyAnimate', model)
            else:
                print(f"Please download easyanimate model to: {model_path}")

        # Load vae
        if OmegaConf.to_container(config['vae_kwargs'])['enable_magvit']:
            Choosen_AutoencoderKL = AutoencoderKLMagvit
        else:
            Choosen_AutoencoderKL = AutoencoderKL
        print("Load Vae.")
        vae = Choosen_AutoencoderKL.from_pretrained(
            model_path, 
            subfolder="vae", 
        ).to(weight_dtype)
        # Update pbar
        pbar.update(1)

        # Load Sampler
        print("Load Sampler.")
        scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder= 'scheduler')
        # Update pbar
        pbar.update(1)
        
        # Load Transformer
        print("Load Transformer.")
        transformer = Transformer3DModel.from_pretrained(
            model_path, 
            subfolder= 'transformer', 
            transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs'])
        ).to(weight_dtype).eval()  
        # Update pbar
        pbar.update(1) 

        # Load Transformer
        if transformer.config.in_channels == 12:
            clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                model_path, subfolder="image_encoder"
            ).to(device, weight_dtype)
            clip_image_processor = CLIPImageProcessor.from_pretrained(
                model_path, subfolder="image_encoder"
            )
        else:
            clip_image_encoder = None
            clip_image_processor = None   
        # Update pbar
        pbar.update(1)

        pipeline = EasyAnimateInpaintPipeline.from_pretrained(
                model_path,
                transformer=transformer,
                scheduler=scheduler,
                vae=vae,
                torch_dtype=weight_dtype,
                clip_image_encoder=clip_image_encoder,
                clip_image_processor=clip_image_processor,
        )
    
        if low_gpu_memory_mode:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline.enable_model_cpu_offload()

        easyanimate_model = {
            'pipeline': pipeline, 
            'dtype': weight_dtype,
            'model_path': model_path,
        }
        return (easyanimate_model,)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136

Image to Video节点定义

EasyAnimateV3的重要功能是实现了Image to Video。我们便以这个为例进行Node的解析。

首先定义了INPUT_TYPES,载入模型的时候我们需要指定的必要参数是:

  • easyanimate_model 代表在载入模块 载入的模型,是一个特殊的种类类别EASYANIMATESMODEL。
  • prompt 代表正向提示词。在这里我们自定义了个文本框,对应的种类格式为STRING_PROMPT。
  • negative_prompt 代表负向提示词。在这里我们自定义了个文本框,对应的种类格式为STRING_PROMPT。
  • video_length 视频的长度。
  • base_resolution 基础模型的分辨率,分别有512、768、960。
  • seed种子。
  • steps预测步数。
  • cfg guidance的强度。
  • scheduler采样器名称。

可选参数是:

  • start_img 开始图片;
  • end_img 结尾图片。

除此之外还需要设置几个变量名:

  • RETURN_TYPES = ("IMAGE",) 代表这个节点返回的内容类别,返回一系列帧图片。
  • RETURN_NAMES = ("images",)代表在UI上显示的节点返回的名称
  • FUNCTION = "process"对应了FUNCTION的函数名
  • CATEGORY = "EasyAnimateWrapper"代表这个节点的种类。

在process函数中,我们预处理了输入图片,并且进行了生成,最后返回了一系列帧图片。

class EasyAnimateI2VSampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "easyanimate_model": (
                    "EASYANIMATESMODEL", 
                ),
                "prompt": (
                    "STRING_PROMPT",
                ),
                "negative_prompt": (
                    "STRING_PROMPT",
                ),
                "video_length": (
                    "INT", {"default": 72, "min": 8, "max": 144, "step": 8}
                ),
                "base_resolution": (
                    [ 
                        512,
                        768,
                        960,
                    ], {"default": 768}
                ),
                "seed": (
                    "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}
                ),
                "steps": (
                    "INT", {"default": 25, "min": 1, "max": 200, "step": 1}
                ),
                "cfg": (
                    "FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.01}
                ),
                "scheduler": (
                    [ 
                        "Euler",
                        "Euler A",
                        "DPM++",
                        "PNDM",
                        "DDIM",
                    ],
                    {
                        "default": 'Euler'
                    }
                )
            },
            "optional":{
                "start_img": ("IMAGE",),
                "end_img": ("IMAGE",),
            },
        }
    
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES =("images",)
    FUNCTION = "process"
    CATEGORY = "EasyAnimateWrapper"

    def process(self, easyanimate_model, prompt, negative_prompt, video_length, base_resolution, seed, steps, cfg, scheduler, start_img=None, end_img=None):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        mm.soft_empty_cache()
        gc.collect()

        start_img = [to_pil(_start_img) for _start_img in start_img] if start_img is not None else None
        end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None
        # Count most suitable height and width
        aspect_ratio_sample_size    = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
        original_width, original_height = start_img[0].size if type(start_img) is list else Image.open(start_img).size
        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
        height, width = [int(x / 16) * 16 for x in closest_size]
        
        # Get Pipeline
        pipeline = easyanimate_model['pipeline']
        model_path = easyanimate_model['model_path']

        # Load Sampler
        if scheduler == "DPM++":
            noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder= 'scheduler')
        elif scheduler == "Euler":
            noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder= 'scheduler')
        elif scheduler == "Euler A":
            noise_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder= 'scheduler')
        elif scheduler == "PNDM":
            noise_scheduler = PNDMScheduler.from_pretrained(model_path, subfolder= 'scheduler')
        elif scheduler == "DDIM":
            noise_scheduler = DDIMScheduler.from_pretrained(model_path, subfolder= 'scheduler')
        pipeline.scheduler = noise_scheduler

        generator= torch.Generator(device).manual_seed(seed)

        with torch.no_grad():
            video_length = int(video_length // pipeline.vae.mini_batch_encoder * pipeline.vae.mini_batch_encoder) if video_length != 1 else 1
            input_video, input_video_mask, clip_image = get_image_to_video_latent(start_img, end_img, video_length=video_length, sample_size=(height, width))

            sample = pipeline(
                prompt, 
                video_length = video_length,
                negative_prompt = negative_prompt,
                height      = height,
                width       = width,
                generator   = generator,
                guidance_scale = cfg,
                num_inference_steps = steps,

                video        = input_video,
                mask_video   = input_video_mask,
                clip_image   = clip_image, 
                comfyui_progressbar = True,
            ).videos
            videos = rearrange(sample, "b c t h w -> (b t) h w c")
        return (videos,)   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112

节点名称映射

NODE_CLASS_MAPPINGS 的key代表comfyUI显示的节点名称,value对应了我们定义的类名。
NODE_DISPLAY_NAME_MAPPINGS 的key代表comfyUI显示的节点名称,value对应了该节点的默认名称。

在这里有两个节点未解析TextBox、EasyAnimateT2VSampler,这里两个节点较为简单,希望同学们举一反三:)

NODE_CLASS_MAPPINGS = {
    "LoadEasyAnimateModel": LoadEasyAnimateModel,
    "TextBox": TextBox,
    "EasyAnimateI2VSampler": EasyAnimateI2VSampler,
    "EasyAnimateT2VSampler": EasyAnimateT2VSampler,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "TextBox": "TextBox",
    "LoadEasyAnimateModel": "Load EasyAnimate Model",
    "EasyAnimateI2VSampler": "EasyAnimate Sampler for Image to Video",
    "EasyAnimateT2VSampler": "EasyAnimate Sampler for Text to Video",
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

init.py文件

在我们仓库的根目录,还需要一个__init__.py文件来导入这些节点:
在这里插入图片描述
具体代码为:

from .comfyui.comfyui_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
  • 1
  • 2
  • 3

插件导入comfyUI

只要将这个文件移入到comfyUI的custom_nodes文件夹即可。
在这里插入图片描述
然后我们可以在ui界面中找到我们定义的这些节点。
在这里插入图片描述

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号