
Diffusion Models (3): Building a Diffusion Model with UNet2DModel


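Preface: After working through the textbook 扩散模型从原理到实战 (Diffusion Models: From Principles to Practice), published by the Epubit community (异步社区, epubit.com), I digested its content and summarize it here.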

UNet2DModel is more advanced than BasicUNet. Compared with BasicUNet, it adds the following improvements:

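  • GroupNorm layers apply group normalization to the input of each block
  • Dropout layers make training smoother
  • Each block contains multiple ResNet layers
  • An attention mechanism is introduced
  • The model can be conditioned on the timestep
  • The upsampling and downsampling blocks have learnable parameters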
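To make the first item concrete, here is a minimal pure-Python sketch of group normalization for a single sample (no batch dimension). The helper `group_norm` is hypothetical and only illustrates the idea: channels are split into groups, and each group is normalized to zero mean and unit variance over all of its values. PyTorch's `torch.nn.GroupNorm` does the same per sample and, by default, also applies a learnable per-channel scale and shift, which this sketch omits.

```python
from statistics import fmean

def group_norm(x, num_groups, eps=1e-5):
    """Normalize one sample x, given as a list of channels (each a flat list of pixel values).

    Channels are split into num_groups consecutive groups; each group is shifted to
    zero mean and scaled to unit variance using statistics over all of its values.
    No learnable scale/shift is applied here.
    """
    num_channels = len(x)
    if num_channels % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    per_group = num_channels // num_groups
    out = []
    for g in range(num_groups):
        group = x[g * per_group:(g + 1) * per_group]
        values = [v for channel in group for v in channel]
        mean = fmean(values)
        var = fmean([(v - mean) ** 2 for v in values])  # biased variance, as in GroupNorm
        std = (var + eps) ** 0.5
        out.extend([[(v - mean) / std for v in channel] for channel in group])
    return out

# 4 channels of 2 "pixels" each, normalized in 2 groups of 2 channels
sample = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
normed = group_norm(sample, num_groups=2)
for channel in normed:
    print([round(v, 3) for v in channel])
```

Both groups come out (almost) identical even though the second group's values are ten times larger: normalization removes each group's own offset and scale, which keeps activations in a similar range from block to block regardless of batch size.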

Create a UNet2DModel:

```python
from diffusers import UNet2DModel

net = UNet2DModel(
    sample_size=28,  # the target image resolution
    in_channels=1,  # the number of input channels, 3 for RGB images
    out_channels=1,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(32, 64, 64),  # roughly matching our basic UNet example
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)
print(net)
print(sum(p.numel() for p in net.parameters()))  # total number of parameters (about 1.7M)
```
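To see how this configuration translates into feature-map shapes, here is a small pure-Python trace. It assumes the diffusers convention that every down block except the last halves the resolution and every up block except the last doubles it; `trace_unet_shapes` is a hypothetical helper, not part of diffusers, and it ignores `layers_per_block` and attention because those do not change tensor shapes.

```python
def trace_unet_shapes(sample_size, in_channels, out_channels, block_out_channels):
    """Return (stage, channels, height) tuples for a square input, following a
    UNet2DModel-style layout (assumption): every down block except the last
    halves the resolution, every up block except the last doubles it."""
    size = sample_size
    trace = [("input", in_channels, size), ("conv_in", block_out_channels[0], size)]
    n_blocks = len(block_out_channels)
    for i, channels in enumerate(block_out_channels):
        if i < n_blocks - 1:  # the last down block keeps its resolution
            if size % 2:
                raise ValueError(f"size {size} cannot be halved evenly in down_block_{i}")
            size //= 2
        trace.append((f"down_block_{i}", channels, size))
    trace.append(("mid_block", block_out_channels[-1], size))
    for i, channels in enumerate(reversed(block_out_channels)):
        if i < n_blocks - 1:  # the last up block keeps its resolution
            size *= 2
        trace.append((f"up_block_{i}", channels, size))
    trace.append(("conv_out", out_channels, size))
    return trace

for stage, channels, size in trace_unet_shapes(28, 1, 1, (32, 64, 64)):
    print(f"{stage:>12}: {channels:3d} channels, {size}x{size}")
```

With `sample_size=28` and three blocks the bottleneck is 7x7. The sketch rejects sizes that cannot be halved evenly (30 fails at 15), because an odd size would likely leave the upsampled maps and the skip connections with mismatched sizes; keeping `sample_size` divisible by 2^(number of blocks - 1) avoids this.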

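Clearly, UNet2DModel is considerably more complex than BasicUNet: it has about 1.7 million parameters, versus just over 300,000 for BasicUNet. Starting from the algorithm in 扩散模型基础(一):基于BasicUNet的扩散模型算法搭建 (Diffusion Model Basics (1): Building a Diffusion Model with BasicUNet, CSDN blog), we only need to replace BasicUNet with UNet2DModel and adjust how predictions are obtained: UNet2DModel takes a timestep argument (always 0 here) and returns an output object, so the prediction is read from its `.sample` attribute.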
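To get a feel for where a parameter count like 1.7M comes from, the arithmetic for individual layers is simple. The sketch below counts a hypothetical 32 to 64 channel residual block made of GroupNorm, 3x3 conv, GroupNorm, 3x3 conv, plus a 1x1 conv on the skip path; it is an illustration under these assumptions, not the exact layer list of diffusers' `ResnetBlock2D` (which, for example, also projects the timestep embedding). For the real model, `sum(p.numel() for p in net.parameters())` gives the exact figure.

```python
def conv2d_params(in_ch: int, out_ch: int, kernel: int, bias: bool = True) -> int:
    """A Conv2d layer stores out_ch x in_ch x kernel x kernel weights plus out_ch biases."""
    return out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)

def group_norm_params(num_channels: int) -> int:
    """GroupNorm with affine=True learns one scale and one shift per channel."""
    return 2 * num_channels

# Hypothetical 32 -> 64 channel residual block; the 1x1 skip conv is needed
# because the channel count changes. Timestep-embedding layers are not counted.
block_layers = {
    "norm1": group_norm_params(32),
    "conv1": conv2d_params(32, 64, 3),
    "norm2": group_norm_params(64),
    "conv2": conv2d_params(64, 64, 3),
    "skip_1x1": conv2d_params(32, 64, 1),
}
for name, count in block_layers.items():
    print(f"{name:>8}: {count:,}")
print(f"   total: {sum(block_layers.values()):,}")
```

The 3x3 convolutions dominate, and their cost grows with the product of input and output channels, so the 64-channel blocks and the attention layers account for most of the total.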

The complete, debugged code is as follows:

```python
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from diffusers import UNet2DModel
from matplotlib import pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True,
                                     transform=torchvision.transforms.ToTensor())

# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Sort shape so broadcasting works
    return x * (1 - amount) + noise * amount

# Create the network
net = UNet2DModel(
    sample_size=28,  # the target image resolution
    in_channels=1,  # the number of input channels, 3 for RGB images
    out_channels=1,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(32, 64, 64),  # roughly matching our basic UNet example
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)  # <<< replaces BasicUNet
net.to(device)

# Our loss function
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):
    for x, _ in train_dataloader:  # labels are not used
        # Get some data and prepare the corrupted version
        x = x.to(device)  # Data on the GPU
        noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
        noisy_x = corrupt(x, noise_amount)  # Create our noisy x

        # Get the model prediction
        pred = net(noisy_x, 0).sample  # <<< Using timestep 0 always, adding .sample

        # Calculate the loss
        loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?

        # Backprop and update the params
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the loss values for this epoch
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# Plot losses and some samples
fig, axs = plt.subplots(1, 2, figsize=(12, 5))

# Losses
axs[0].plot(losses)
axs[0].set_ylim(0, 0.1)
axs[0].set_title('Loss over time')

# Samples: start from pure noise and repeatedly move toward the model's prediction
n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
    with torch.no_grad():
        pred = net(x, 0).sample  # <<< timestep 0 and .sample, as in training
    mix_factor = 1 / (n_steps - i)
    x = x * (1 - mix_factor) + pred * mix_factor
axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Generated Samples')
plt.show()
```
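The two blending formulas in the script, `corrupt` during training and the `mix_factor` update during sampling, can be checked without a GPU. The following is a pure-Python sketch on flat lists of pixel values; `constant_model` is a hypothetical stand-in for `net(x, 0).sample`, used only to show how the update converges.

```python
import random

def corrupt(x, amount, rng=random):
    """Blend clean values x with uniform noise: amount=0 keeps x, amount=1 gives pure noise."""
    return [xi * (1 - amount) + rng.random() * amount for xi in x]

def sample(predict, x, n_steps):
    """Iterative refinement as in the sampling loop: at step i, move x a
    fraction 1/(n_steps - i) of the way toward the current prediction."""
    for i in range(n_steps):
        pred = predict(x)
        mix_factor = 1 / (n_steps - i)
        x = [xi * (1 - mix_factor) + pi * mix_factor for xi, pi in zip(x, pred)]
    return x

rng = random.Random(0)
clean = [0.0, 0.5, 1.0]
print(corrupt(clean, 0.0, rng))  # amount 0: unchanged
print(corrupt(clean, 0.5, rng))  # amount 0.5: half clean, half noise

# Hypothetical stand-in for the trained network: always predicts 0.25 everywhere
constant_model = lambda x: [0.25] * len(x)
start = [rng.random() for _ in range(3)]
print(sample(constant_model, start, n_steps=40))
```

Because the last step uses `mix_factor = 1`, the final output is exactly the model's last prediction; the earlier steps (1/40, 1/39, ...) move only a little at a time, so the network repeatedly refines its own partially denoised output.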

Results:

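Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog (【wpsshop博客】). Copyright belongs to the original author, and this site assumes no legal responsibility. If you find infringing content, please contact us. When reposting, please credit the source: https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/771708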