
Diffusion Demystified (2): Implementing the Diffusion Noising Process in PyTorch

In the previous post, Diffusion Demystified (1): Diffusion Explained in Plain Words (No Formulas), I covered the basics of how diffusion works, including the forward noising process. In this post I implement that noising process in PyTorch to show concretely how the noise is added.
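For reference, the closed-form noising step that the code below implements is the standard DDPM expression (the formula itself is not spelled out in the original post; \bar{\alpha}_t corresponds to alphas_cumprod in the code, and its square roots to sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod):

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \mathbf{I}), \qquad \bar{\alpha}_t = \prod_{i=1}^{t} (1 - \beta_i)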

import torch
import math
import numpy as np
from PIL import Image
import requests
import matplotlib.pyplot as plot
import cv2


def linear_beta_schedule(timesteps):
    """
    linear schedule, proposed in original ddpm paper
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


# number of timesteps
timesteps = 1000

# beta schedule: the linear version, as in the original DDPM paper
# (cosine_beta_schedule can be used instead)
betas = linear_beta_schedule(timesteps=timesteps)

# derive alpha from beta
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# quantities needed for the forward process q(x_t | x_0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)


def extract(a, t, x_shape):
    # pick the values of a at indices t and reshape so they broadcast over x_shape
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# forward diffusion process: add noise to x_start at timestep t
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    # save the sampled Gaussian noise for visualization
    cv2.imwrite('noise.png', noise.numpy() * 255)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    print('sqrt_alphas_cumprod_t :', sqrt_alphas_cumprod_t)
    print('sqrt_one_minus_alphas_cumprod_t :', sqrt_one_minus_alphas_cumprod_t)

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise


# image post-processing
def get_noisy_image(x_start, t):
    # add noise
    x_noisy = q_sample(x_start, t=t)
    # turn back into a numpy image
    noisy_image = x_noisy.squeeze().numpy()
    return noisy_image

...

# save the noisy images at t = 0, 50, 100, 500, 999
x_start = cv2.imread('img.png') / 255.0   # load the test image and scale to [0, 1]
x_start = torch.tensor(x_start, dtype=torch.float)
cv2.imwrite('img_0.png', get_noisy_image(x_start, torch.tensor([0])) * 255.0)
cv2.imwrite('img_50.png', get_noisy_image(x_start, torch.tensor([50])) * 255.0)
cv2.imwrite('img_100.png', get_noisy_image(x_start, torch.tensor([100])) * 255.0)
cv2.imwrite('img_500.png', get_noisy_image(x_start, torch.tensor([500])) * 255.0)
cv2.imwrite('img_999.png', get_noisy_image(x_start, torch.tensor([999])) * 255.0)

Running the script prints the scaling coefficients that q_sample uses at t = 0, 50, 100, 500 and 999:

sqrt_alphas_cumprod_t : tensor([[[0.9999]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.0100]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.9849]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.1733]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.9461]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.3238]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.2789]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.9603]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.0064]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[1.0000]]], dtype=torch.float64)
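As the comment in the listing notes, the linear schedule can be swapped for cosine_beta_schedule. Below is a minimal sketch for comparing the two; it assumes the schedule functions defined above are in scope, and the output file name schedules.png is just a placeholder of mine. It plots how the cumulative product alphas_cumprod (the squared signal coefficient) decays under each schedule:

import torch
import matplotlib.pyplot as plt

T = 1000
for name, schedule_fn in [('linear', linear_beta_schedule), ('cosine', cosine_beta_schedule)]:
    abar = torch.cumprod(1. - schedule_fn(T), dim=0)  # alpha_bar_t for this schedule
    plt.plot(abar.numpy(), label=name)

plt.xlabel('t')
plt.ylabel('alphas_cumprod')
plt.legend()
plt.savefig('schedules.png')  # placeholder file name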

Below are the original image and the results at t = 0, 50, 100, 500 and 999.

As t grows, the coefficient on the original image shrinks while the coefficient on the noise grows. At t = 500 the outline of the face is still faintly visible; by t = 999 the face has completely drowned in the noise.
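To put numbers on that observation, here is a small sketch that reuses sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod from the listing above (the loop and variable names are mine) and prints the signal coefficient, the noise coefficient and their ratio at the same timesteps:

# tabulate the two coefficients at the timesteps shown above
for step in [0, 50, 100, 500, 999]:
    sig = sqrt_alphas_cumprod[step].item()
    noi = sqrt_one_minus_alphas_cumprod[step].item()
    print(f't={step}: signal coef={sig:.4f}, noise coef={noi:.4f}, ratio={sig / noi:.2f}')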
