当前位置:   article > 正文

【深度学习模型】扩散模型(Diffusion Model)基本原理及代码讲解_2d扩散模型原理

2d扩散模型原理
  1. 前言

生成式建模的扩散思想实际上已经在2015年(Sohl-Dickstein等人)提出,然而,直到2019年斯坦福大学(Song等人)、2020年Google Brain(Ho等人)才改进了这个方法,从此引发了生成式模型的新潮流。目前,包括OpenAI的GLIDE和DALL-E 2,海德堡大学的Latent Diffusion和Google Brain的ImageGen,都基于diffusion模型,并可以得到高质量的生成效果。本文以下讲解主要基于DDPM,并适当地增加一些目前有效的改进内容。

  1. 基本原理

扩散模型包括两个步骤:

  1. 固定的(或预设的)前向扩散过程q:该过程会逐渐将高斯噪声添加到图像中,直到最终得到纯噪声。

  1. 训练的反向去噪扩散过程:训练一个神经网络,从纯噪音开始逐渐去噪,直到得到一个真实图像。

前向与后向的步数由下标 t定义,并且有预先定义好的总步数 T(DDPM原文中为1000)。

t=0 时为从数据集中采样得到的一张真实图片, t=T 时近似为一张纯粹的噪声。

2.1 直观理解

为了看懂扩散模型查了很多资料,但是要么就是大量的数学公式,一行行公式推完了还是不知道它想干啥。要么就是高视角,上来就和能量模型,VAE放一块儿对比说共同点和不同点,看完还是云里雾里。然而事实上下面几句话就能把扩散模型说明白了
扩散模型的目的是什么?
学习从纯噪声生成图片的方法
扩散模型是怎么做的?
训练一个U-Net,接受一系列加了噪声的图片,学习预测所加的噪声
前向过程在干啥?
逐步向真实图片添加噪声最终得到一个纯噪声
对于训练集中的每张图片,都能生成一系列的噪声程度不同的加噪图片
在训练时,这些 【不同程度的噪声图片 + 生成它们所用的噪声】 是实际的训练样本
反向过程在干啥?
训练好模型后,采样、生成图片

2.2 数学形式

2.2.1 前向过程

是真实数据分布(也就是真实的大量图片),从这个分布中采样即可得到一张真实图片 。我们定义前向扩散过程为 ,即每一个step向图片添加噪声的过程,并定义好一系列,则有:

其中,N为正态分布,均值和方差分别为,因此通过采样标准正态分布,有:

2.2.2 反向过程

那么问题的核心就是如何得到的逆过程 ,这个过程无法直接求出来,所以我们使用神经网络去拟合这一分布。我们使用一个具有参数的神经网络去计算 。假设反向的条件概率分布也是高斯分布,且高斯分布实际上只有两个参数:均值和方差,那么神经网络需要计算的实际上是

在DDPM中,方差被固定,网络只学习均值。而之后的改进模型中,方差也可由网络学习得到。

2.2.3 总结过程
总之,我们定义这么一个过程:给一张图片逐步加噪声直到变成纯粹的噪声,然后对噪声进行去噪得到真实的图片。所谓的扩散模型就是让神经网络学习这个去除噪声的方法。
所谓的加噪声,就是基于稍微干净的图片计算一个(多维)高斯分布(每个像素点都有一个高斯分布,且均值就是这个像素点的值,方差是预先定义的 ),然后从这个多维分布中抽样一个数据出来,这个数据就是加噪之后的结果。显然,如果方差非常非常小,那么每个抽样得到的像素点就和原本的像素点的值非常接近,也就是加了一个非常非常小的噪声。如果方差比较大,那么抽样结果就会和原本的结果差距较大。
去噪声也是同理,我们基于稍微噪声的图片 计算一个条件分布,我们希望从这个分布中抽样得到的是相比于 更加接近真实图片的稍微干净的图片。我们假设这样的条件分布是存在的,并且也是个高斯分布,那么我们只需要知道均值和方差就可以了。问题是这个均值和方差是无法直接计算的,所以用神经网络去学习近似这样一个高斯分布。

2.3 网络训练流程

我们最终要训练的实际上是一个噪声预测器。神经网络输出的噪声是,而真实的噪声取自于正态分布。则损失函数为:

预测网络方面,DDPM采用了 U-Net。

从而,网络的训练流程为:

  1. 我们接受一个随机的样本

  1. 我们随机从 1 到 T 采样一个 t;

  1. 我们从高斯分布采样一些噪声并且施加在输入上;

  1. 网络从被影响过后的噪声图片学习其被施加了的噪声。

  1. 代码

3.1 Network helpers

先是一些辅助函数和类。


    
    
  1. def exists( x):
  2. return x is not None
  3. # 有val时返回val,val为None时返回d
  4. def default( val, d):
  5. if exists(val):
  6. return val
  7. return d() if isfunction(d) else d
  8. # 残差模块,将输入加到输出上
  9. class Residual(nn.Module):
  10. def __init__( self, fn):
  11. super().__init__()
  12. self.fn = fn
  13. def forward( self, x, *args, **kwargs):
  14. return self.fn(x, *args, **kwargs) + x
  15. # 上采样(反卷积)
  16. def Upsample( dim):
  17. return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
  18. # 下采样
  19. def Downsample( dim):
  20. return nn.Conv2d(dim, dim, 4, 2, 1)
  • 1

3.2 Positional embeddings

类似于Transformer的positional embedding,为了让网络知道当前处理的是一系列去噪过程中的哪一个step,我们需要将步数 t 也编码并传入网络之中。DDPM采用正弦位置编码(Sinusoidal Positional Embeddings)。这一方法的输入是shape为 (batch_size, 1) 的 tensor,也就是batch中每一个sample所处的t ,并将这个tensor转换为shape为 (batch_size, dim) 的 tensor。这个tensor会被加到每一个残差模块中。


    
    
  1. class SinusoidalPositionEmbeddings(nn.Module):
  2. def __init__( self, dim):
  3. super().__init__()
  4. self.dim = dim
  5. def forward( self, time):
  6. device = time.device
  7. half_dim = self.dim // 2
  8. embeddings = math.log( 10000) / (half_dim - 1)
  9. embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
  10. embeddings = time[:, None] * embeddings[ None, :]
  11. embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=- 1)
  12. return embeddings
  • 1

3.3 ResNet/ConvNeXT block

U-Net的Block实现,可以用ResNet或ConvNeXT。


    
    
  1. class Block(nn.Module):
  2. def __init__( self, dim, dim_out, groups = 8):
  3. super().__init__()
  4. self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
  5. self.norm = nn.GroupNorm(groups, dim_out)
  6. self.act = nn.SiLU()
  7. def forward( self, x, scale_shift = None):
  8. x = self.proj(x)
  9. x = self.norm(x)
  10. if exists(scale_shift):
  11. scale, shift = scale_shift
  12. x = x * (scale + 1) + shift
  13. x = self.act(x)
  14. return x
  15. class ResnetBlock(nn.Module):
  16. """Deep Residual Learning for Image Recognition"""
  17. def __init__( self, dim, dim_out, *, time_emb_dim=None, groups=8):
  18. super().__init__()
  19. self.mlp = (
  20. nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
  21. if exists(time_emb_dim)
  22. else None
  23. )
  24. self.block1 = Block(dim, dim_out, groups=groups)
  25. self.block2 = Block(dim_out, dim_out, groups=groups)
  26. self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
  27. def forward( self, x, time_emb=None):
  28. h = self.block1(x)
  29. if exists(self.mlp) and exists(time_emb):
  30. time_emb = self.mlp(time_emb)
  31. h = rearrange(time_emb, "b c -> b c 1 1") + h
  32. h = self.block2(h)
  33. return h + self.res_conv(x)
  34. class ConvNextBlock(nn.Module):
  35. """A ConvNet for the 2020s"""
  36. def __init__( self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
  37. super().__init__()
  38. self.mlp = (
  39. nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
  40. if exists(time_emb_dim)
  41. else None
  42. )
  43. self.ds_conv = nn.Conv2d(dim, dim, 7, padding= 3, groups=dim)
  44. Get an email address at self.net. It 's ad-free, reliable email that's based on your own name | self.net = nn.Sequential(
  45. nn.GroupNorm( 1, dim) if norm else nn.Identity(),
  46. nn.Conv2d(dim, dim_out * mult, 3, padding= 1),
  47. nn.GELU(),
  48. nn.GroupNorm( 1, dim_out * mult),
  49. nn.Conv2d(dim_out * mult, dim_out, 3, padding= 1),
  50. )
  51. self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
  52. def forward( self, x, time_emb=None):
  53. h = self.ds_conv(x)
  54. if exists(self.mlp) and exists(time_emb):
  55. condition = self.mlp(time_emb)
  56. h = h + rearrange(condition, "b c -> b c 1 1")
  57. h = Get an email address at self.net. It 's ad-free, reliable email that's based on your own name | self.net(h)
  58. return h + self.res_conv(x)
  • 1

3.4 Attention module

包含两种attention模块,一个是常规的 multi-head self-attention,一个是 linear attention variant。


    
    
  1. class Attention(nn.Module):
  2. def __init__( self, dim, heads=4, dim_head=32):
  3. super().__init__()
  4. self.scale = dim_head**- 0.5
  5. self.heads = heads
  6. hidden_dim = dim_head * heads
  7. self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias= False)
  8. self.to_out = nn.Conv2d(hidden_dim, dim, 1)
  9. def forward( self, x):
  10. b, c, h, w = x.shape
  11. qkv = self.to_qkv(x).chunk( 3, dim= 1)
  12. q, k, v = map(
  13. lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
  14. )
  15. q = q * self.scale
  16. sim = einsum( "b h d i, b h d j -> b h i j", q, k)
  17. sim = sim - sim.amax(dim=- 1, keepdim= True).detach()
  18. attn = sim.softmax(dim=- 1)
  19. out = einsum( "b h i j, b h d j -> b h i d", attn, v)
  20. out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
  21. return self.to_out(out)
  22. class LinearAttention(nn.Module):
  23. def __init__( self, dim, heads=4, dim_head=32):
  24. super().__init__()
  25. self.scale = dim_head**- 0.5
  26. self.heads = heads
  27. hidden_dim = dim_head * heads
  28. self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias= False)
  29. self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
  30. nn.GroupNorm( 1, dim))
  31. def forward( self, x):
  32. b, c, h, w = x.shape
  33. qkv = self.to_qkv(x).chunk( 3, dim= 1)
  34. q, k, v = map(
  35. lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
  36. )
  37. q = q.softmax(dim=- 2)
  38. k = k.softmax(dim=- 1)
  39. q = q * self.scale
  40. context = torch.einsum( "b h d n, b h e n -> b h d e", k, v)
  41. out = torch.einsum( "b h d e, b h d n -> b h e n", context, q)
  42. out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
  43. return self.to_out(out)
  • 1

3.5 Group normalization

DDPM的作者对U-Net的卷积/注意力层使用GN正则化。下面,我们定义了一个PreNorm类,它将被用于在注意力层之前应用groupnorm。值得注意的是,归一化在Transformer中是在注意力之前还是之后应用,目前仍存在着争议。


    
    
  1. class PreNorm(nn.Module):
  2. def __init__( self, dim, fn):
  3. super().__init__()
  4. self.fn = fn
  5. self.norm = nn.GroupNorm( 1, dim)
  6. def forward( self, x):
  7. x = self.norm(x)
  8. return self.fn(x)
  • 1

3.6 Conditional U-Net

现在,我们已经定义了所有的组件,接下来就是定义完整的网络了。

输入:噪声图片的batch+这些图片各自的t。

输出:预测每个图片上所添加的噪声。

Input:a batch of noisy images of shape ( batch_size, num_channels, h, w ) and a batch of steps of shape ( batch_size, 1 )
output: a tensor of shape ( batch_size, num_channels, h, w )

具体的网络结构:

  1. 首先,输入通过一个卷积层,同时计算step t 所对应的embedding

  1. 通过一系列的下采样stage,每个stage都包含:2个ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + downsample operation

  1. 在网络中间,应用一个带attention的ResNet或者ConvNeXT

  1. 通过一系列的上采样stage,每个stage都包含:2个ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + upsample operation

  1. 最终,通过一个ResNet/ConvNeXT blocl和一个卷积层。


    
    
  1. class Unet(nn.Module):
  2. def __init__(
  3. self,
  4. dim,
  5. init_dim=None,
  6. out_dim=None,
  7. dim_mults=(1, 2, 4, 8),
  8. channels=3,
  9. with_time_emb=True,
  10. resnet_block_groups=8,
  11. use_convnext=True,
  12. convnext_mult=2,
  13. ):
  14. super().__init__()
  15. # determine dimensions
  16. self.channels = channels
  17. init_dim = default(init_dim, dim // 3 * 2)
  18. self.init_conv = nn.Conv2d(channels, init_dim, 7, padding= 3)
  19. dims = [init_dim, * map( lambda m: dim * m, dim_mults)]
  20. in_out = list( zip(dims[:- 1], dims[ 1:]))
  21. if use_convnext:
  22. block_klass = partial(ConvNextBlock, mult=convnext_mult)
  23. else:
  24. block_klass = partial(ResnetBlock, groups=resnet_block_groups)
  25. # time embeddings
  26. if with_time_emb:
  27. time_dim = dim * 4
  28. self.time_mlp = nn.Sequential(
  29. SinusoidalPositionEmbeddings(dim),
  30. nn.Linear(dim, time_dim),
  31. nn.GELU(),
  32. nn.Linear(time_dim, time_dim),
  33. )
  34. else:
  35. time_dim = None
  36. self.time_mlp = None
  37. # layers
  38. self.downs = nn.ModuleList([])
  39. self.ups = nn.ModuleList([])
  40. num_resolutions = len(in_out)
  41. for ind, (dim_in, dim_out) in enumerate(in_out):
  42. is_last = ind >= (num_resolutions - 1)
  43. self.downs.append(
  44. nn.ModuleList(
  45. [
  46. block_klass(dim_in, dim_out, time_emb_dim=time_dim),
  47. block_klass(dim_out, dim_out, time_emb_dim=time_dim),
  48. Residual(PreNorm(dim_out, LinearAttention(dim_out))),
  49. Downsample(dim_out) if not is_last else nn.Identity(),
  50. ]
  51. )
  52. )
  53. mid_dim = dims[- 1]
  54. self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
  55. self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
  56. self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
  57. for ind, (dim_in, dim_out) in enumerate( reversed(in_out[ 1:])):
  58. is_last = ind >= (num_resolutions - 1)
  59. self.ups.append(
  60. nn.ModuleList(
  61. [
  62. block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
  63. block_klass(dim_in, dim_in, time_emb_dim=time_dim),
  64. Residual(PreNorm(dim_in, LinearAttention(dim_in))),
  65. Upsample(dim_in) if not is_last else nn.Identity(),
  66. ]
  67. )
  68. )
  69. out_dim = default(out_dim, channels)
  70. self.final_conv = nn.Sequential(
  71. block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
  72. )
  73. def forward( self, x, time):
  74. x = self.init_conv(x)
  75. t = self.time_mlp(time) if exists(self.time_mlp) else None
  76. h = []
  77. # downsample
  78. for block1, block2, attn, downsample in self.downs:
  79. x = block1(x, t)
  80. x = block2(x, t)
  81. x = attn(x)
  82. h.append(x)
  83. x = downsample(x)
  84. # bottleneck
  85. x = self.mid_block1(x, t)
  86. x = self.mid_attn(x)
  87. x = self.mid_block2(x, t)
  88. # upsample
  89. for block1, block2, attn, upsample in self.ups:
  90. x = torch.cat((x, h.pop()), dim= 1)
  91. x = block1(x, t)
  92. x = block2(x, t)
  93. x = attn(x)
  94. x = upsample(x)
  95. return self.final_conv(x)
  • 1

3.7 定义前向扩散过程

DDPM中使用linear schedule定义 后续的研究指出使用cosine schedule可能会有更好的效果。

接下来是一些简单的对于 schedule 的定义,从当中选一个使用即可。


    
    
  1. def cosine_beta_schedule( timesteps, s=0.008):
  2. """
  3. cosine schedule as proposed in https://arxiv.org/abs/2102.09672
  4. """
  5. steps = timesteps + 1
  6. x = torch.linspace( 0, timesteps, steps)
  7. alphas_cumprod = torch.cos(((x / timesteps) + s) / ( 1 + s) * torch.pi * 0.5) ** 2
  8. alphas_cumprod = alphas_cumprod / alphas_cumprod[ 0]
  9. betas = 1 - (alphas_cumprod[ 1:] / alphas_cumprod[:- 1])
  10. return torch.clip(betas, 0.0001, 0.9999)
  11. def linear_beta_schedule( timesteps):
  12. beta_start = 0.0001
  13. beta_end = 0.02
  14. return torch.linspace(beta_start, beta_end, timesteps)
  15. def quadratic_beta_schedule( timesteps):
  16. beta_start = 0.0001
  17. beta_end = 0.02
  18. return torch.linspace(beta_start** 0.5, beta_end** 0.5, timesteps) ** 2
  19. def sigmoid_beta_schedule( timesteps):
  20. beta_start = 0.0001
  21. beta_end = 0.02
  22. betas = torch.linspace(- 6, 6, timesteps)
  23. return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
  • 1

我们按照DDPM中用第二种的linear,将 T 设置为200,并将每个 t 下的各种参数提前计算好。


    
    
  1. timesteps = 200
  2. # define beta schedule
  3. betas = linear_beta_schedule(timesteps=timesteps)
  4. # define alphas
  5. alphas = 1. - betas
  6. alphas_cumprod = torch.cumprod(alphas, axis= 0)
  7. alphas_cumprod_prev = F.pad(alphas_cumprod[:- 1], ( 1, 0), value= 1.0)
  8. sqrt_recip_alphas = torch.sqrt( 1.0 / alphas)
  9. # calculations for diffusion q(x_t | x_{t-1}) and others
  10. sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
  11. sqrt_one_minus_alphas_cumprod = torch.sqrt( 1. - alphas_cumprod)
  12. # calculations for posterior q(x_{t-1} | x_t, x_0)
  13. posterior_variance = betas * ( 1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod)
  14. def extract( a, t, x_shape):
  15. batch_size = t.shape[ 0]
  16. out = a.gather(- 1, t.cpu())
  17. return out.reshape(batch_size, *(( 1,) * ( len(x_shape) - 1))).to(t.device)
  • 1

我们用一个实例来说明前向加噪过程。


    
    
  1. from PIL import Image
  2. import requests
  3. url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
  4. image = Image. open(requests.get(url, stream= True).raw)
  5. image
  • 1

    
    
  1. from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize
  2. image_size = 128
  3. transform = Compose([
  4. Resize(image_size),
  5. CenterCrop(image_size),
  6. ToTensor(), # turn into Numpy array of shape HWC, divide by 255
  7. Lambda( lambda t: (t * 2) - 1),
  8. ])
  9. x_start = transform(image).unsqueeze( 0)
  10. x_start.shape # 输出的结果是 torch.Size([1, 3, 128, 128])
  11. import numpy as np
  12. reverse_transform = Compose([
  13. Lambda( lambda t: (t + 1) / 2),
  14. Lambda( lambda t: t.permute( 1, 2, 0)), # CHW to HWC
  15. Lambda( lambda t: t * 255.),
  16. Lambda( lambda t: t.numpy().astype(np.uint8)),
  17. ToPILImage(),
  18. ])
  • 1

准备齐全,接下来就可以定义正向扩散过程了。


    
    
  1. # forward diffusion (using the nice property)
  2. def q_sample( x_start, t, noise=None):
  3. if noise is None:
  4. noise = torch.randn_like(x_start)
  5. sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
  6. sqrt_one_minus_alphas_cumprod_t = extract(
  7. sqrt_one_minus_alphas_cumprod, t, x_start.shape
  8. )
  9. return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
  10. def get_noisy_image( x_start, t):
  11. # add noise
  12. x_noisy = q_sample(x_start, t=t)
  13. # turn back into PIL image
  14. noisy_image = reverse_transform(x_noisy.squeeze())
  15. return noisy_image
  • 1

可视化一下多个不同t的生成结果。


    
    
  1. import matplotlib.pyplot as plt
  2. # use seed for reproducability
  3. torch.manual_seed( 0)
  4. # source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
  5. def plot( imgs, with_orig=False, row_title=None, **imshow_kwargs):
  6. if not isinstance(imgs[ 0], list):
  7. # Make a 2d grid even if there's just 1 row
  8. imgs = [imgs]
  9. num_rows = len(imgs)
  10. num_cols = len(imgs[ 0]) + with_orig
  11. fig, axs = plt.subplots(figsize=( 200, 200), nrows=num_rows, ncols=num_cols, squeeze= False)
  12. for row_idx, row in enumerate(imgs):
  13. row = [image] + row if with_orig else row
  14. for col_idx, img in enumerate(row):
  15. ax = axs[row_idx, col_idx]
  16. ax.imshow(np.asarray(img), **imshow_kwargs)
  17. ax. set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
  18. if with_orig:
  19. axs[ 0, 0]. set(title= 'Original image')
  20. axs[ 0, 0].title.set_size( 8)
  21. if row_title is not None:
  22. for row_idx in range(num_rows):
  23. axs[row_idx, 0]. set(ylabel=row_title[row_idx])
  24. plt.tight_layout()
  25. plot([get_noisy_image(x_start, torch.tensor([t])) for t in [ 0, 50, 100, 150, 199]])
  • 1

3.8 定义损失函数


    
    
  1. def p_losses( denoise_model, x_start, t, noise=None, loss_type="l1"):
  2. # 先采样噪声
  3. if noise is None:
  4. noise = torch.randn_like(x_start)
  5. # 用采样得到的噪声去加噪图片
  6. x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
  7. predicted_noise = denoise_model(x_noisy, t)
  8. # 根据加噪了的图片去预测采样的噪声
  9. if loss_type == 'l1':
  10. loss = F.l1_loss(noise, predicted_noise)
  11. elif loss_type == 'l2':
  12. loss = F.mse_loss(noise, predicted_noise)
  13. elif loss_type == "huber":
  14. loss = F.smooth_l1_loss(noise, predicted_noise)
  15. else:
  16. raise NotImplementedError()
  17. return loss
  • 1

3.9 定义数据集 PyTorch Dataset 和 DataLoader

我们使用mnist数据集构造了一个 DataLoader,每个batch由128张 normalize 过的 image 组成。


    
    
  1. from datasets import load_dataset
  2. # load dataset from the hub
  3. dataset = load_dataset( "fashion_mnist")
  4. image_size = 28
  5. channels = 1
  6. batch_size = 128
  7. from torchvision import transforms
  8. from torch.utils.data import DataLoader
  9. transform = Compose([
  10. transforms.RandomHorizontalFlip(),
  11. transforms.ToTensor(),
  12. transforms.Lambda( lambda t: (t * 2) - 1)
  13. ])
  14. def transforms( examples):
  15. examples[ "pixel_values"] = [transform(image.convert( "L")) for image in examples[ "image"]]
  16. del examples[ "image"]
  17. return examples
  18. transformed_dataset = dataset.with_transform(transforms).remove_columns( "label")
  19. dataloader = DataLoader(transformed_dataset[ "train"], batch_size=batch_size, shuffle= True)
  20. batch = next( iter(dataloader))
  21. print(batch.keys())    # dict_keys(['pixel_values'])
  • 1

3.10 采样

采样过程发生在反向去噪时。对于一张纯噪声,扩散模型一步步地去除噪声最终得到真实图片,采样事实上就是定义的去除噪声这一行为。 观察采样算法中第四行, t−1 步的图片是由 t 步的图片减去一个噪声得到的,只不过这个噪声是由网络拟合出来,并且 rescale 过的而已。 这里要注意第四行式子的最后一项,采样时每一步也都会加上一个从正态分布采样的纯噪声。理想情况下,最终我们会得到一张看起来像是从真实数据分布中采样得到的图片。


    
    
  1. @torch.no_grad()
  2. def p_sample( model, x, t, t_index):
  3. betas_t = extract(betas, t, x.shape)
  4. sqrt_one_minus_alphas_cumprod_t = extract(
  5. sqrt_one_minus_alphas_cumprod, t, x.shape
  6. )
  7. sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
  8. # Equation 11 in the paper
  9. # Use our model (noise predictor) to predict the mean
  10. model_mean = sqrt_recip_alphas_t * (
  11. x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
  12. )
  13. if t_index == 0:
  14. return model_mean
  15. else:
  16. posterior_variance_t = extract(posterior_variance, t, x.shape)
  17. noise = torch.randn_like(x)
  18. # Algorithm 2 line 4:
  19. return model_mean + torch.sqrt(posterior_variance_t) * noise
  20. # Algorithm 2 (including returning all images)
  21. @torch.no_grad()
  22. def p_sample_loop( model, shape):
  23. device = next(model.parameters()).device
  24. b = shape[ 0]
  25. # start from pure noise (for each example in the batch)
  26. img = torch.randn(shape, device=device)
  27. imgs = []
  28. for i in tqdm( reversed( range( 0, timesteps)), desc= 'sampling loop time step', total=timesteps):
  29. img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
  30. imgs.append(img.cpu().numpy())
  31. return imgs
  32. @torch.no_grad()
  33. def sample( model, image_size, batch_size=16, channels=3):
  34. return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
  • 1

3.11 训练

先定义一些辅助生成图片的函数。


    
    
  1. from pathlib import Path
  2. def num_to_groups( num, divisor):
  3. groups = num // divisor
  4. remainder = num % divisor
  5. arr = [divisor] * groups
  6. if remainder > 0:
  7. arr.append(remainder)
  8. return arr
  9. results_folder = Path( "./results")
  10. results_folder.mkdir(exist_ok = True)
  11. save_and_sample_every = 1000
  • 1

接下来实例化模型。


    
    
  1. from torch.optim import Adam
  2. device = "cuda" if torch.cuda.is_available() else "cpu"
  3. model = Unet(
  4. dim=image_size,
  5. channels=channels,
  6. dim_mults=( 1, 2, 4,)
  7. )
  8. model.to(device)
  9. optimizer = Adam(model.parameters(), lr= 1e-3)
  • 1

开始训练!


    
    
  1. from torchvision.utils import save_image
  2. epochs = 6
  3. for epoch in range(epochs):
  4. for step, batch in enumerate(dataloader):
  5. optimizer.zero_grad()
  6. batch_size = batch[ "pixel_values"].shape[ 0]
  7. batch = batch[ "pixel_values"].to(device)
  8. # Algorithm 1 line 3: sample t uniformally for every example in the batch
  9. t = torch.randint( 0, timesteps, (batch_size,), device=device).long()
  10. loss = p_losses(model, batch, t, loss_type= "huber")
  11. if step % 100 == 0:
  12. print( "Loss:", loss.item())
  13. loss.backward()
  14. optimizer.step()
  15. # save generated images
  16. if step != 0 and step % save_and_sample_every == 0:
  17. milestone = step // save_and_sample_every
  18. batches = num_to_groups( 4, batch_size)
  19. all_images_list = list( map( lambda n: sample(model, batch_size=n, channels=channels), batches))
  20. all_images = torch.cat(all_images_list, dim= 0)
  21. all_images = (all_images + 1) * 0.5
  22. save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)
  • 1

Inference:


    
    
  1. # sample 64 images
  2. samples = sample(model, image_size=image_size, batch_size= 64, channels=channels)
  3. # show a random one
  4. random_index = 5
  5. plt.imshow(samples[- 1][random_index].reshape(image_size, image_size, channels), cmap= "gray")
  • 1

    
    
  1. import matplotlib.animation as animation
  2. random_index = 53
  3. fig = plt.figure()
  4. ims = []
  5. for i in range(timesteps):
  6. im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap= "gray", animated= True)
  7. ims.append([im])
  8. animate = animation.ArtistAnimation(fig, ims, interval= 50, blit= True, repeat_delay= 1000)
  9. animate.save( 'diffusion.gif')
  10. plt.show()
  • 1

4. 参考文献

原理+代码:Diffusion Model 直观理解

The Annotated Diffusion Model

【diffusion】扩散模型详解!理论+代码

原文链接:
https://blog.csdn.net/tobefans/article/details/129728036

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/479329
推荐阅读
相关标签
  

闽ICP备14008679号