参加了Datawhale组织的扩散模型学习活动,完成了第一单元的学习工作,这里简单记录下学习所得。
扩散模型从本质上来说是一种特殊的生成模型,所以在具体展开之前,先学习生成模型的基本定义。
相比于判别模型输出一个具体的类别或预测值,生成模型要建模的是数据的概率分布:假设训练数据服从某个真实分布 p(x),训练集可以看作从该分布中采样得到的观测样本 x,模型的任务就是利用这些样本去估计这个真实分布。
而扩散模型就属于其中一种新的模型。其利用了物理学中的扩散思想,严格来说包括了前向扩散(增加噪声)和反向去噪(减少噪声)两个过程。
下面以经典的DDPM(Denoising Diffusion Probabilistic Model)为例进行说明,大致过程可以概括为下图。
可以看出,整个过程是按时间步迭代进行的:生成从随机噪声开始,经过多个步骤逐渐细化,直到得到输出图像。在每一步中,模型都会估计如何从当前输入走向完全去噪的版本;由于每一步只做很小的更新,早期阶段(此时预测最终输出非常困难)估计中的误差可以在后续步骤中被逐渐纠正。
训练模型的过程相对简单,主要分为以下五步:
1、从训练数据中加载一批图像
2、添加不同程度的噪声。注意,我们希望模型既能“修复”(去噪)噪声极强的图像,也能处理只带少量噪声、接近干净的图像
3、将这些带噪版本输入模型
4、评估模型对这些输入的去噪效果
5、根据评估结果(损失)更新模型权重
简单来说,就是让训练集中干净的图片先通过添加噪声的方式变成噪声版本的图片,再送到模型中,让它预测噪声或者降噪后的图片,与干净图片作对比,以更新权重。
当要用训练好的模型生成新图像时,我们从一个完全随机的输入开始,反复将其送入模型,每次根据模型的预测对输入做少量更新;配合合适的采样策略,就可以在较少的时间步内逐渐生成完整的图像。
具体数学原理如下:
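以 DDPM 为例,这里按原论文的通用记法简要梳理其核心公式(记号为一般约定):前向过程在每个时间步按固定方差系数 β_t 向数据注入高斯噪声,并且可以从 x_0 一步采样得到任意时刻的 x_t:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t\mathbf{I}\big), \qquad x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})$$

其中 $\alpha_t = 1-\beta_t$,$\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$。反向过程则训练一个网络 $\epsilon_\theta$ 去预测所加入的噪声,采用简化后的训练目标:

$$L_{\text{simple}} = \mathbb{E}_{t,\,x_0,\,\epsilon}\Big[\big\|\epsilon - \epsilon_\theta(x_t,\ t)\big\|^2\Big]$$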
在此数学框架的基础上,扩散模型经历了一系列发展。基础扩散模型的提出与改进:最早被广泛采用的扩散模型是DDPM,它将去噪扩散概率模型成功应用到图像生成任务中。
采样器:通过离散化求解随机微分方程,降低采样步数。
基于显式分类器引导的扩散模型:OpenAI的《Diffusion Models Beat GANs on Image Synthesis》论文介绍了如何在扩散过程中利用显式分类器进行引导。
基于CLIP的多模态图像生成:将同一语义的文字和图片转换到同一个隐空间中。
大模型的“再学习”方法:DreamBooth让现有模型“再学习”指定主体的图像,通过少量训练将主体绑定到唯一的文本标识符,之后即可通过输入prompt控制该主体生成不同的图像。LoRA可以针对指定风格或人物的数据集进行微调,并将其融入现有的图像生成中。ControlNet学习多模态信息,利用分割图、边缘图等更精细地控制图像生成。
在应用层面,扩散模型已经覆盖了多个领域。AI作画:Midjourney、DreamStudio、Adobe Firefly,以及百度的文心一格AI创作平台、阿里的通义文生图大模型。
计算机视觉:图像分割与目标检测、图像超分辨率(串联多个扩散模型)、图像修复、图像翻译和图像编辑。
时序数据预测:TimeGrad模型,使用RNN处理历史数据并保存到隐空间,对数据添加噪声实现扩散过程,可处理数千维度的多元数据完成预测。
自然语言:使用Diffusion-LM可以应用在语句生成、语言翻译、问答对话、搜索补全、情感分析、文章续写等任务中。
基于文本的多模态:文本生成图像(DALLE-2、Imagen、Stable Diffusion)、文本生成视频(Make-A-Video、ControlNet Video)、文本生成3D(DiffRF)
AI基础科学:SMCDiff(支架蛋白质生成)、CDVAE(扩散晶体变分自编码器模型)
在了解了扩散模型的基本原理之后,下面开始动手搭建,先从最简单的模型开始。本次在 Google Colab 环境下搭建和运行。
1、设置和导入:
!pip install -q diffusers
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
Using device: cuda
2、设置并加载数据集。这里采用规模较小的 MNIST 数据集,也可以换成 FashionMNIST 来增加训练难度。
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
简单查看下数据集:
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])
每张图片都是28*28的灰度图,像素值在0-1之间。
3、初步探索腐蚀(加噪声)过程。假设我们没有读过任何论文,最直观的加噪方法是定义一个线性系数来控制噪声的比例,这里用 amount 表示,则可以这样做:
noise = torch.rand_like(x)  # 创建一个和输入图像相同尺寸的随机噪声
noisy_x = (1-amount)*x + amount*noise  # 按 amount 比例混入噪声得到带噪图像
直观来看,当 amount=0 时得到的就是原图,当 amount=1 时得到的就是纯噪声;在 (0,1) 之间则是输入和噪声的混合,即带噪图像。
定义噪声函数,简单查看下结果:
def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Sort shape so broadcasting works
    return x*(1-amount) + noise*amount
# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
# Adding noise
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)
# Plotting the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');
可以直观地看出,随着噪声系数的增加,图片逐渐向随机噪声过渡。不过这种线性混合的设计会让不少加噪图片仍然很容易辨认出原始内容,与实际扩散模型的做法并不一致,后文会详细讨论原因。
4、建立模型。模型接收 28×28 的带噪图像并输出相同形状的预测,一个流行的选择是称为 UNet 的架构。UNet 最初是为医学图像分割任务而提出的,它由一条“收缩路径”和一条“扩展路径”组成:数据先被逐层压缩,再被扩展回原始分辨率(类似自编码器),同时通过跳跃连接让信息和梯度可以在不同层级之间流动。
为简单起见,这里采用一个简化版本:单通道图像先经过下行路径上的三个卷积层(代码中的 down_layers),再经过上行路径上的三个卷积层,并在下行层和上行层之间建立跳跃连接。下采样用最大池化、上采样用 nn.Upsample,而不像完整的 UNet 那样使用可学习的上下采样层。各层输出的通道数大致如下:
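(下面的结构示意按后文代码整理,括号内为特征图尺寸)
下行:Conv 1→32 (28×28) → MaxPool → Conv 32→64 (14×14) → MaxPool → Conv 64→64 (7×7)
上行:Conv 64→64 (7×7) → Upsample+跳跃连接 → Conv 64→32 (14×14) → Upsample+跳跃连接 → Conv 32→1 (28×28)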
代码如下:
class BasicUNet(nn.Module):
    """A minimal UNet implementation."""
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.SiLU()  # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))  # Through the layer and the activation function
            if i < 2:  # For all but the third (final) down layer:
                h.append(x)  # Storing output for skip connection
                x = self.downscale(x)  # Downscale ready for the next layer
        for i, l in enumerate(self.up_layers):
            if i > 0:  # For all except the first up layer
                x = self.upscale(x)  # Upscale
                x += h.pop()  # Fetching stored output (skip connection)
            x = self.act(l(x))  # Through the layer and the activation function
        return x
简单检验输出:
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
torch.Size([8, 1, 28, 28])
查看网络参数
sum([p.numel() for p in net.parameters()])
309057
5、训练网络
大致过程为:1、获取一批数据;2、按随机比例添加噪声;3、将带噪数据输入模型;4、将模型预测与干净图像进行比较以计算损失;5、相应地更新模型的参数。
代码如下:
# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

# Create the network
net = BasicUNet()
net.to(device)

# Our loss function
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        # Get some data and prepare the corrupted version
        x = x.to(device)  # Data on the GPU
        noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
        noisy_x = corrupt(x, noise_amount)  # Create our noisy x

        # Get the model prediction
        pred = net(noisy_x)

        # Calculate the loss
        loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the loss values for this epoch:
    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);
验证模型输出:取一批数据,按不同程度加噪,然后查看模型的预测效果:
#@markdown Visualizing model predictions on noisy inputs:
# Fetch some data
x, y = next(iter(train_dataloader))
x = x[:8]  # Only using the first 8 for easy plotting

# Corrupt with a range of amounts
amount = torch.linspace(0, 1, x.shape[0])  # Left to right -> more corruption
noised_x = corrupt(x, amount)

# Get the model predictions
with torch.no_grad():
    preds = net(noised_x.to(device)).detach().cpu()

# Plot
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');
可以看到,模型在低噪声图片上预测较好,但随着噪声比例上升,预测效果越来越差;当 amount 接近 1 时,输入几乎完全是随机噪声,模型的输出也只剩下模糊的猜测。问题的根源在于生成过程不够平滑:我们要求模型从纯噪声一步到位地给出结果,因此需要改进采样过程。
改进思路是不让模型一步就得到最终结果,而是让它分多个迭代步逐渐逼近,就好比人的学习要循序渐进,欲速则不达。
具体做法是:从随机噪声开始,先得到模型的预测,但只向这个预测移动一小步(比如 20%)。这样我们得到一张仍然很嘈杂、但已带有一丝结构的图像,再把它输入模型得到新的预测。由于新的起点不那么嘈杂,这次的预测应该会更好一些,于是可以再向它迈出一小步,如此反复,逐步提高预测部分的占比。
这里给出一个五迭代步的例子:
#@markdown Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device)  # Start from random
step_history = [x.detach().cpu()]
pred_output_history = []

for i in range(n_steps):
    with torch.no_grad():  # No need to track gradients during inference
        pred = net(x)  # Predict the denoised x0
    pred_output_history.append(pred.detach().cpu())  # Store model output for plotting
    mix_factor = 1/(n_steps - i)  # How much we move towards the prediction
    x = x*(1-mix_factor) + pred*mix_factor  # Move part of the way there
    step_history.append(x.detach().cpu())  # Store step for plotting

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap='Greys')
左边为输入图片,右边为预测图片,可以看到经过多次迭代,模型预测输出正在变好。
以此思路,扩大采样步数,可以获得更好的输出图像:
#@markdown Showing more results, using 40 sampling steps
n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
    noise_amount = torch.ones((x.shape[0], )).to(device) * (1-(i/n_steps))  # Starting high going low
    with torch.no_grad():
        pred = net(x)
    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')
虽然效果仍然不好,但至少去噪数字有的已经可以被识别出来了。可以训练更长时间来获得更好的效果。
下面聊聊实际的DDPM模型与上述简化模型实现思路有何不同。
首先在网络结构上,diffusers 提供的 UNet2DModel 比上面的 BasicUNet 更完善,其主要改进有:
1、GroupNorm:对每个块的输入做组归一化
2、Dropout 层:减少过拟合,使训练更稳定
3、每个块包含多个 ResNet 层(只要 layers_per_block 未设置为 1)
4、空间注意力层(通常只用于分辨率较低的块)
5、上采样和下采样模块为可学习的层
6、时间步条件:将时间步的嵌入作为模型的额外输入
模型如下:
model = UNet2DModel(
    sample_size=28,           # the target image resolution
    in_channels=1,            # the number of input channels, 3 for RGB images
    out_channels=1,           # the number of output channels
    layers_per_block=2,       # how many ResNet layers to use per UNet block
    block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
    down_block_types=(
        "DownBlock2D",        # a regular ResNet downsampling block
        "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",          # a regular ResNet upsampling block
    ),
)
print(model)
参数数量:
sum([p.numel() for p in model.parameters()]) # 1.7M vs the ~309k parameters of the BasicUNet
1707009
可见参数数量比之前大很多。
训练过程(这里关闭时间步控制,统一以时间步0为输入):
#@markdown Trying UNet2DModel instead of BasicUNet:
# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

# Create the network
net = UNet2DModel(
    sample_size=28,           # the target image resolution
    in_channels=1,            # the number of input channels, 3 for RGB images
    out_channels=1,           # the number of output channels
    layers_per_block=2,       # how many ResNet layers to use per UNet block
    block_out_channels=(32, 64, 64),  # Roughly matching our basic unet example
    down_block_types=(
        "DownBlock2D",        # a regular ResNet downsampling block
        "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",          # a regular ResNet upsampling block
    ),
)  #<<<
net.to(device)

# Our loss function
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        # Get some data and prepare the corrupted version
        x = x.to(device)  # Data on the GPU
        noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
        noisy_x = corrupt(x, noise_amount)  # Create our noisy x

        # Get the model prediction
        pred = net(noisy_x, 0).sample  #<<< Using timestep 0 always, adding .sample

        # Calculate the loss
        loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the loss values for this epoch:
    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# Plot losses and some samples
fig, axs = plt.subplots(1, 2, figsize=(12, 5))

# Losses
axs[0].plot(losses)
axs[0].set_ylim(0, 0.1)
axs[0].set_title('Loss over time')

# Samples
n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
    noise_amount = torch.ones((x.shape[0], )).to(device) * (1-(i/n_steps))  # Starting high going low
    with torch.no_grad():
        pred = net(x, 0).sample
    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor
axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Generated Samples');
可以看出训练结果有了质的飞跃。
下面对加噪声的原理做下探讨。
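回顾前面给出的前向公式,DDPMScheduler 的 add_noise 实现的正是 $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$ 这一步;下面所说的“两个系数”就是 $\sqrt{\bar{\alpha}_t}$ 和 $\sqrt{1-\bar{\alpha}_t}$。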
可视化一下两个系数随时间步的变化:
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");
最初,带噪的 x 主要由原图构成(sqrt_alpha_prod ≈ 1),但随着时间步增加,原图的贡献下降、噪声分量上升。与我们之前按 amount 对 x 和噪声做简单线性混合不同,这里噪声的占比上升得相对更快。我们可以在一些数据上可视化这一点:
#@markdown visualize the DDPM noising process for different timesteps:
# Noise a batch of images to view the effect
fig, axs = plt.subplots(3, 1, figsize=(16, 10))
xb, yb = next(iter(train_dataloader))
xb = xb.to(device)[:8]
xb = xb * 2. - 1.  # Map to (-1, 1)
print('X shape', xb.shape)

# Show clean inputs
axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap='Greys')
axs[0].set_title('Clean X')

# Add noise with scheduler
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)  # << NB: randn not rand
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print('Noisy X shape', noisy_xb.shape)

# Show noisy version (with and without clipping)
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap='Greys')
axs[1].set_title('Noisy X (clipped to (-1, 1))')
axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(), cmap='Greys')
axs[2].set_title('Noisy X');
另一个不同之处是,DDPM 添加的是从高斯分布(torch.randn,均值 0、标准差 1)中采样的噪声,而不是我们原始 corrupt 函数里用的 0 到 1 之间的均匀噪声(torch.rand)。此外,通常还会对训练数据做归一化(例如映射到 (-1, 1))。
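可以用一小段代码直观对比这两种噪声的统计量(仅作示意):
uniform_noise = torch.rand(8, 1, 28, 28)    # corrupt 函数里用的 (0, 1) 均匀噪声
gaussian_noise = torch.randn(8, 1, 28, 28)  # DDPM 使用的均值 0、标准差 1 的高斯噪声
print(uniform_noise.mean().item(), uniform_noise.std().item())   # 约 0.5 和 0.29
print(gaussian_noise.mean().item(), gaussian_noise.std().item()) # 约 0.0 和 1.0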
上述代码与原论文还有一个差别:原论文让模型预测所添加的噪声,而上面的实现让模型直接预测去噪后的图片。如果改为预测噪声,核心代码大致如下:
noise = torch.randn_like(xb)  # << NB: randn not rand
noisy_x = noise_scheduler.add_noise(xb, noise, timesteps)
model_prediction = model(noisy_x, timesteps).sample
loss = F.mse_loss(model_prediction, noise)  # noise as the target
再来说时间步条件:模型把时间步也作为输入。其背后的想法是,把噪声水平的信息提供给模型,有助于它更好地完成去噪任务。虽然不加时间步条件也能训练,但在很多情况下它确实能提升效果,目前文献中的大多数实现也都包含了这一条件。
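下面是传入时间步的调用示意(假设 net 为上文的 UNet2DModel;由于上文训练时固定使用时间步 0,该模型并未真正学会利用时间步信息,这里仅演示调用方式):
x = torch.rand(8, 1, 28, 28).to(device)
with torch.no_grad():
    pred = net(x, 500).sample  # 第二个参数即时间步,正式训练时应传入与加噪程度对应的 t
print(pred.shape)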
预测(采样)过程和简化版本一样:迭代多次,每次去除一点噪声,最终得到生成结果。
经过上述从零开始的扩散模型探索,下面利用 Hugging Face 的 diffusers 库做一次实战练习。
1、安装配置环境:
%pip install -qq -U diffusers datasets transformers accelerate ftfy pyarrow==9.0.0
from huggingface_hub import notebook_login
notebook_login()
%%capture
!sudo apt -qq install git-lfs
!git config --global credential.helper store
2、导入库与定义函数:
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image

def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL"""
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im

def make_grid(images, size=64):  # 将多张 PIL 图像拼成一行方便查看
    """Given a list of PIL images, stack them together into a line for easy viewing"""
    output_im = Image.new("RGB", (size * len(images), size))
    for i, im in enumerate(images):
        output_im.paste(im.resize((size, size)), (i * size, 0))
    return output_im

# Mac users may need device = 'mps' (untested)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3、加载预训练模型
from diffusers import StableDiffusionPipeline
# Check out https://huggingface.co/sd-dreambooth-library for loads of models from the community
model_id = "sd-dreambooth-library/mr-potato-head"
# Load the pipeline
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(
device
)
4、生成图像:
prompt = "an abstract oil painting of sks mr potato head by picasso"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image
可以看到生成了一幅土豆头先生的图像。pipe 的两个参数中,num_inference_steps 用来调节推理步数,即上文所说的迭代步,步数越多生成效果通常越好,但计算量也越大;guidance_scale 控制生成结果对提示词(prompt)的依赖程度,数值越高越贴近提示词,越低则生成越自由。
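可以简单对比一下不同参数的效果(参数取值仅作演示):
image_fast = pipe(prompt, num_inference_steps=20, guidance_scale=7.5).images[0]   # 步数更少,更快但细节可能变差
image_loose = pipe(prompt, num_inference_steps=50, guidance_scale=3.0).images[0]  # 引导更弱,生成更自由
make_grid([image_fast, image_loose], size=256)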
Diffusers的核心API大致可以分为以下三种(列表后给出一个三者组合的极简示意):
1、Pipelines:高级类,旨在以用户友好的方式用流行的预训练扩散模型快速生成样本
2、Models:用于训练新扩散模型的流行架构,例如UNet。
3、Schedulers:用于在推理过程中从噪声生成图像以及生成用于训练的噪声图像的各种技术。
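三者组合的极简示意如下(模型和调度器的配置仅作演示,未经训练时直接采样只会得到噪声):
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)  # Models:定义网络结构
scheduler = DDPMScheduler(num_train_timesteps=1000)                 # Schedulers:定义加噪/去噪策略
pipeline = DDPMPipeline(unet=model, scheduler=scheduler)            # Pipelines:封装端到端的采样流程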
再举一个Pipelines的例子加深理解,生成蝴蝶图片:
from diffusers import DDPMPipeline
# Load the butterfly pipeline
butterfly_pipeline = DDPMPipeline.from_pretrained(
"johnowhitaker/ddpm-butterflies-32px"
).to(device)
# Create 8 images
images = butterfly_pipeline(batch_size=8).images
# View the result
make_grid(images)
虽然效果不如之前的例子,但这个模型的训练数据量也小得多。
下面介绍一个完整的过程。
1、下载训练数据集,这里使用huggingface的1000张蝴蝶图片:
import torchvision
from datasets import load_dataset
from torchvision import transforms

dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")

# Or load images from a local folder
# dataset = load_dataset("imagefolder", data_dir="path/to/folder")

# We'll train on 32-pixel square images, but you can try larger sizes too
image_size = 32
# You can lower your batch size if you're running out of GPU memory
batch_size = 64

# Define data augmentations
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)

def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

dataset.set_transform(transform)

# Create a dataloader from the dataset to serve up the transformed images in batches
train_dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True
)
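可以先取一个批次做个简单可视化,确认数据加载和预处理正常(这一批 xb 在后面添加噪声时也会用到):
xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)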
2、定义Scheduler,用于在训练和推理过程中按不同时间步添加噪声:
from diffusers import DDPMScheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
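DDPMScheduler 还支持不同的 beta_schedule(控制加噪快慢的安排方式),后文训练时会用到余弦方案,这里给一个简单的对比示例:
noise_scheduler_linear = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")
noise_scheduler_cosine = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
plt.plot(noise_scheduler_linear.alphas_cumprod ** 0.5, label="linear")
plt.plot(noise_scheduler_cosine.alphas_cumprod ** 0.5, label="squaredcos_cap_v2")
plt.legend();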
3、添加噪声:
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)
可以看到随着时间步的增加,增添的噪声越来越多。
4、定义模型,这里采用U-Net,模型如下:
原理:输入图像先经过若干个 ResNet 下采样块,再经过对应的上采样块恢复分辨率;下采样各级的输出通过跳跃连接(skip connection)传给对应的上采样块。
采用Diffusers里面的UNet2DModel:
from diffusers import UNet2DModel

# Create a model
model = UNet2DModel(
    sample_size=image_size,   # the target image resolution
    in_channels=3,            # the number of input channels, 3 for RGB images
    out_channels=3,           # the number of output channels
    layers_per_block=2,       # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
    down_block_types=(
        "DownBlock2D",        # a regular ResNet downsampling block
        "DownBlock2D",
        "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",          # a regular ResNet upsampling block
    ),
)
model.to(device);
5、训练:
# Set the noise scheduler
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

losses = []

for epoch in range(30):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["images"].to(device)
        # Sample noise to add to the images
        noise = torch.randn(clean_images.shape).to(clean_images.device)
        bs = clean_images.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # Calculate the loss
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        losses.append(loss.item())

        # Update the model parameters with the optimizer
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

训练输出如下:
Epoch:5, loss: 0.16273280512541533
Epoch:10, loss: 0.11161588924005628
Epoch:15, loss: 0.10206522420048714
Epoch:20, loss: 0.08302505919709802
Epoch:25, loss: 0.07805309211835265
Epoch:30, loss: 0.07474562455900013
绘制损失曲线:
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))
plt.show()
6、生成图像:
from diffusers import DDPMPipeline
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)
pipeline_output = image_pipe()
pipeline_output.images[0]
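生成没问题之后,也可以把整条 pipeline 保存下来,方便之后直接加载复用(保存路径仅为示例):
image_pipe.save_pretrained("my_ddpm_butterflies")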
编写采样循环:
# Random starting point (8 random images):
sample = torch.randn(8, 3, 32, 32).to(device)
for i, t in enumerate(noise_scheduler.timesteps):
    # Get model pred
    with torch.no_grad():
        residual = model(sample, t).sample

    # Update sample with step
    sample = noise_scheduler.step(residual, t, sample).prev_sample

show_images(sample)
其中 step 函数根据模型预测的噪声(residual)更新当前样本。由此可见,生成过程就是每一步预测噪声、再由调度器逐步去噪,最终得到生成图像。