当前位置:   article > 正文

Load LoRA XXX.safetensors to diffusers (StableDiffusionPipeline/StableDiffusionControlNetPipeline)_safetensors转diffusers

safetensors转diffusers
  1. def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
  2. LORA_PREFIX_UNET = "lora_unet"
  3. LORA_PREFIX_TEXT_ENCODER = "lora_te"
  4. # load LoRA weight from .safetensors
  5. state_dict = load_file(checkpoint_path, device=device)
  6. updates = defaultdict(dict)
  7. for key, value in state_dict.items():
  8. # it is suggested to print out the key, it usually will be something like below
  9. # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
  10. layer, elem = key.split('.', 1)
  11. updates[layer][elem] = value
  12. # directly update weight in diffusers model
  13. for layer, elems in updates.items():
  14. if "text" in layer:
  15. layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
  16. curr_layer = pipeline.text_encoder
  17. else:
  18. layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
  19. curr_layer = pipeline.unet
  20. # find the target layer
  21. temp_name = layer_infos.pop(0)
  22. while len(layer_infos) > -1:
  23. try:
  24. curr_layer = curr_layer.__getattr__(temp_name)
  25. if len(layer_infos) > 0:
  26. temp_name = layer_infos.pop(0)
  27. elif len(layer_infos) == 0:
  28. break
  29. except Exception:
  30. if len(temp_name) > 0:
  31. temp_name += "_" + layer_infos.pop(0)
  32. else:
  33. temp_name = layer_infos.pop(0)
  34. # get elements for this layer
  35. weight_up = elems['lora_up.weight'].to(dtype)
  36. weight_down = elems['lora_down.weight'].to(dtype)
  37. alpha = elems['alpha']
  38. if alpha:
  39. alpha = alpha.item() / weight_up.shape[1]
  40. else:
  41. alpha = 1.0
  42. curr_layer.weight.data = curr_layer.weight.data.to(device)
  43. # update weight
  44. if len(weight_up.shape) == 4:
  45. curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
  46. else:
  47. curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
  48. return pipeline

Example usage:

# Merge the LoRA file at `lora_path` into `pipe` with multiplier 1.0,
# loading it on 'cuda' and casting the LoRA matrices to float16.
pipe = load_lora_weights(pipe, lora_path, 1.0, 'cuda', torch.float16)

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/代码探险家/article/detail/748639
推荐阅读
相关标签
  

闽ICP备14008679号