当前位置:   article > 正文

stable_diffusion代码运行过程_stable diffusion expected reduction dim to be spec

stable diffusion expected reduction dim to be specified for input.numel() ==

加载模型以及参数

加载参数

首先在Main函数的最开始,新建argparse对象parser,向parser中输入参数以及模型信息,再将这些信息转化为opt

	arser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    opt = parser.parse_args()
    config = OmegaConf.load(f"{opt.config}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

下面是debug后config的值:

{'model': {'base_learning_rate': 0.0001, 
'target': 'ldm.models.diffusion.ddpm.LatentDiffusion',
 'params': {'linear_start': 0.00085, 'linear_end': 0.012, 'num_timesteps_cond': 1, 'log_every_t': 200, 'timesteps': 1000, 'first_stage_key': 'jpg', 'cond_stage_key': 'txt', 'image_size': 64, 'channels': 4, 'cond_stage_trainable': False, 'conditioning_key': 'crossattn', 'monitor': 'val/loss_simple_ema', 'scale_factor': 0.18215, 'use_ema': False, 
 'personalization_config': {'target': 'ldm.modules.embedding_manager.EmbeddingManager', 'params': {'placeholder_strings': ['*'], 'initializer_words': ['sculpture'], 
 'per_image_tokens': False, 'num_vectors_per_token': 1, 'progressive_words': False}}, 
 'unet_config': {'target': 'ldm.modules.diffusionmodules.openaimodel.UNetModel', 
 'params': {'image_size': 32, 'in_channels': 4, 'out_channels': 4, 'model_channels': 320, 'attention_resolutions': [4, 2, 1], 'num_res_blocks': 2, 'channel_mult': [1, 2, 4, 4], 'num_heads': 8, 'use_spatial_transformer': True, 'transformer_depth': 1, 'context_dim': 768, 'use_checkpoint': True, 'legacy': False}}, 
 'first_stage_config': {'target': 'ldm.models.autoencoder.AutoencoderKL',
  'params': {'embed_dim': 4, 'monitor': 'val/rec_loss', 'ddconfig': {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}, '
  lossconfig': {'target': 'torch.nn.Identity'}}}, 
  'cond_stage_config': {'target': 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'}}}}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

我们可以看到,config主要记录了LatentDiffusion的超参数以及模型参数,其中具体包括unet参数,first_stage参数,cond_stage参数,这三者可以认为是LatentDiffusion三个阶段,而这些参数都是parser从yaml文件中(configs/stable-diffusion/v1-inference.yaml)读取的:

model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False
        
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70

加载模型

之后使用下面的代码加载模型,其中调用关系复杂:

model = load_model_from_config(config, f"{opt.ckpt}")
  • 1

首先会调用load_model_from_config函数,该函数接受config和ckpt,从ckpt文件加载模型的状态字典,并将其加载到根据config文件创建的模型中,还打印了可能存在的缺失或意外的键

def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    #从ckpt文件加载一个字典,其中包含模型的状态字典state_dict
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    #从配置文件config中实例化一个模型
    model = instantiate_from_config(config.model)
    #使用load_state_dict()方法将模型状态字典sd加载到模型中
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

其中instantiate_from_config函数的具体作用是检查config字典中是否存在名为"target"的键,如果不存在,将检查config是否等于’is_first_stage’或’is_unconditional’,如果"target"键存在于config中,将使用get_obj_from_str()函数根据config[“target”]的值实例化一个对象:

def instantiate_from_config(config, **kwargs):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

其中涉及get_obj_from_str函数,它根据给定的字符串string实例化一个对象,在debug过程中,string输入为’ldm.models.diffusion.ddpm.LatentDiffusion’,也就是实例化ldm:

def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    #获取实例属性
    return getattr(importlib.import_module(module, package=None), cls)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

由于string输入为’ldm.models.diffusion.ddpm.LatentDiffusion’,module就是’ldm.models.diffusion.ddpm’,就是cls就是LatentDiffusion,也就是说要去’ldm.models.diffusion.ddpm这个类去使用getattr函数取出LatentDiffusion的类属性。那么接下来就是加载ddpm代码:
进入DDPM类后,先进行各种初始化,在其中self.model = DiffusionWrapper(unet_config, conditioning_key)这段代码调用DiffusionWrapper类初始化模型:

传入的参数是unet_config,和conditioning_key

class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        #根据传入的unet参数进行实例化
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

进入openaimodel.py文件中的unet类去实例化unet对象,其中unet就是先计算time_embed,再下采样(由ResBlock、AttentionBlock和TimestepEmbedSequential组成),中间层,上采样
接下来ldm函数中的super().__init__(conditioning_key=conditioning_key, *args, **kwargs)会跳转到ddpm类,通过register_schedule计算所需参数

接下来就是处理第一阶段模型,使用self.instantiate_first_stage(first_stage_config)初始化模型参数,其中方法和上述差不多,在进入autoencoder.py中的AutoencoderKL类后,首先初始化encoder和decoder,之后加载必要参数:

    def instantiate_first_stage(self, config):
    	#根据config字典来实例化一个模型(model)
        model = instantiate_from_config(config)
        #将实例化后的模型赋值给self.first_stage_model
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        #禁用参数的梯度计算
        for param in self.first_stage_model.parameters():
            param.requires_grad = False
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

在之后进行初始化cond_stage模型,还是调用self.instantiate_cond_stage(cond_stage_config),这次到get_obj_from_str方法时的string=‘ldm.modules.encoders.modules.FrozenCLIPEmbedder’,也就是要去’FrozenCLIPEmbedder’中提取属性
在FrozenCLIPEmbedder类中我们可以设置预训练的参数,

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

进行到这里以及完成了ldm模型的加载,下面进行加载采样器:

sampler = DDIMSampler(model)

class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

接下来将空提示词加载为条件向量:

if opt.scale != 1.0:
	uc = model.get_learned_conditioning(batch_size * [""])
  • 1
  • 2

这条命令执行会调用以下函数:

	#其中c=['', '', '']
    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
        	#有encode
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
            	#将c encode一下
                c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

接下来对于encode会调用

    def encode(self, text, **kwargs):
        return self(text, **kwargs)
  • 1
  • 2

转到clip的前向传播函数:

    def forward(self, text, **kwargs):
    	#按照self.max_length的大小对['', '', '']进行编码
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        #取出token
        tokens = batch_encoding["input_ids"].to(self.device)
        #调用transformer的前向函数        
        z = self.transformer(input_ids=tokens, **kwargs)

        return z
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/628853
推荐阅读
相关标签
  

闽ICP备14008679号