当前位置:   article > 正文

【AutoencoderKL】基于stable-diffusion-v1.4的vae对图像重构

【AutoencoderKL】基于stable-diffusion-v1.4的vae对图像重构

模型地址:https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main/vae
主要参考:Using-Stable-Diffusion-VAE-to-encode-satellite-images
在这里插入图片描述

sd1.4 vae

下载到本地

from diffusers import AutoencoderKL
from PIL import Image
import  torch
import torchvision.transforms as T

#  ./huggingface/stable-diffusion-v1-4/vae 切换为任意本地路径
vae = AutoencoderKL.from_pretrained("./huggingface/stable-diffusion-v1-4/vae",variant='fp16')
# c:\Users\zeng\Downloads\vae_config.json

def encode_img(input_img):
    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
        # Transform the image to a tensor and normalize it
    transform = T.Compose([
        # T.Resize((256, 256)),
        T.ToTensor()
    ])
    input_img = transform(input_img)
    if len(input_img.shape)<4:
        input_img = input_img.unsqueeze(0)
    with torch.no_grad():
        latent = vae.encode(input_img*2 - 1) # Note scaling
    return 0.18215 * latent.latent_dist.sample()



def decode_img(latents):
    # bath of latents -> list of images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu()
    # image = T.Resize(original_size)(image.squeeze())
    return T.ToPILImage()(image.squeeze())

if __name__ == '__main__':
    # Load an example image
    input_img = Image.open("huge.jpg")
    original_size = input_img.size
    print('original_size',original_size)

    # Encode and decode the image
    latents = encode_img(input_img)
    reconstructed_img = decode_img(latents)

    # Save the reconstructed image
    reconstructed_img.save("reconstructed_example2.jpg")
    # Concatenate the original and reconstructed images
    concatenated_img = Image.new('RGB', (original_size[0] * 2, original_size[1]))
    concatenated_img.paste(input_img, (0, 0))
    concatenated_img.paste(reconstructed_img, (original_size[0], 0))
    # Save the concatenated image
    concatenated_img.save("concatenated_example2.jpg")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/801846
推荐阅读
相关标签
  

闽ICP备14008679号