当前位置:   article > 正文

扩散模型学习——代码学习_扩散模型代码

扩散模型代码

引言

  • 这是第一次接触扩散模型,为了学习,这里好好分析一下他的代码

正文

UNet网络结构

  • 这部分主要是定义一下网络结构,以及相关的网络超参数
  • 具体网络结构的图片如下

在这里插入图片描述
下述为网络结构各个层的定义

  • 结合定义和模型的具体输出,会更加理解
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features
        super(ContextUnet, self).__init__()

        # number of input channels, number of intermediate feature maps and number of classes
        # 输入通道数
        self.in_channels = in_channels
        # 映射特征数量
        self.n_feat = n_feat
        # 生成类别数
        self.n_cfeat = n_cfeat
        # 生成的是方形图,并且输入必须能够被4整除
        self.h = height  #assume h == w. must be divisible by 4, so 28,24,20,16...

        # Initialize the initial convolutional layer
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # 初始化下采样层
        self.down1 = UnetDown(n_feat, n_feat)        # down1 #[10, 256, 8, 8]
        self.down2 = UnetDown(n_feat, 2 * n_feat)    # down2 #[10, 256, 4,  4]
        
         # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        # 仅仅进行平均池化,并没有改变他的通道数
        self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())

        # Embed the timestep and context labels with a one-layer fully connected neural network
        # 定义两个嵌入层,将时间戳信息和上下文消息都转为对应的embedding向量
        # 这里仅仅是改变通道数,并没有改变上下文信息的特征
        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)

        # Initialize the up-sampling path of the U-Net with three levels
        # 并不改变通道数,仅仅是进行上采样
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample  
            nn.GroupNorm(8, 2 * n_feat), # normalize                       
            nn.ReLU(),
        )
        
        # 降低通道数,并进行上采样,同下
        self.up1 = UnetUp(4 * n_feat, n_feat)
        # 降低通道数,并进行上采样,这里输入通道和up1的输出通道不同,是因为还有上下文信息和之前下采样的输出
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # 初始化最终的卷积层,将最终的输出映射为和输入相同大小
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps   #in_channels, out_channels, kernel_size, stride=1, padding=0
            nn.GroupNorm(8, n_feat), # normalize
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input
        )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53

网络结构的每一层参数如下

# 初始化卷积层
init_conv.conv1.0.weight torch.Size([256, 3, 3, 3])
init_conv.conv1.0.bias torch.Size([256])
init_conv.conv1.1.weight torch.Size([256])
init_conv.conv1.1.bias torch.Size([256])
init_conv.conv2.0.weight torch.Size([256, 256, 3, 3])
init_conv.conv2.0.bias torch.Size([256])
init_conv.conv2.1.weight torch.Size([256])
init_conv.conv2.1.bias torch.Size([256])

# 下采样层一
down1.model.0.conv1.0.weight torch.Size([256, 256, 3, 3])
down1.model.0.conv1.0.bias torch.Size([256])
down1.model.0.conv1.1.weight torch.Size([256])
down1.model.0.conv1.1.bias torch.Size([256])
down1.model.0.conv2.0.weight torch.Size([256, 256, 3, 3])
down1.model.0.conv2.0.bias torch.Size([256])
down1.model.0.conv2.1.weight torch.Size([256])
down1.model.0.conv2.1.bias torch.Size([256])
down1.model.1.conv1.0.weight torch.Size([256, 256, 3, 3])
down1.model.1.conv1.0.bias torch.Size([256])
down1.model.1.conv1.1.weight torch.Size([256])
down1.model.1.conv1.1.bias torch.Size([256])
down1.model.1.conv2.0.weight torch.Size([256, 256, 3, 3])
down1.model.1.conv2.0.bias torch.Size([256])
down1.model.1.conv2.1.weight torch.Size([256])
down1.model.1.conv2.1.bias torch.Size([256])

# 下采样层二
down2.model.0.conv1.0.weight torch.Size([512, 256, 3, 3])
down2.model.0.conv1.0.bias torch.Size([512])
down2.model.0.conv1.1.weight torch.Size([512])
down2.model.0.conv1.1.bias torch.Size([512])
down2.model.0.conv2.0.weight torch.Size([512, 512, 3, 3])
down2.model.0.conv2.0.bias torch.Size([512])
down2.model.0.conv2.1.weight torch.Size([512])
down2.model.0.conv2.1.bias torch.Size([512])
down2.model.1.conv1.0.weight torch.Size([512, 512, 3, 3])
down2.model.1.conv1.0.bias torch.Size([512])
down2.model.1.conv1.1.weight torch.Size([512])
down2.model.1.conv1.1.bias torch.Size([512])
down2.model.1.conv2.0.weight torch.Size([512, 512, 3, 3])
down2.model.1.conv2.0.bias torch.Size([512])
down2.model.1.conv2.1.weight torch.Size([512])
down2.model.1.conv2.1.bias torch.Size([512])

# 时间上下文信息embedding
timeembed1.model.0.weight torch.Size([512, 1])
timeembed1.model.0.bias torch.Size([512])
timeembed1.model.2.weight torch.Size([512, 512])
timeembed1.model.2.bias torch.Size([512])
timeembed2.model.0.weight torch.Size([256, 1])
timeembed2.model.0.bias torch.Size([256])
timeembed2.model.2.weight torch.Size([256, 256])
timeembed2.model.2.bias torch.Size([256])

# 上下文信息的embedding
contextembed1.model.0.weight torch.Size([512, 10])
contextembed1.model.0.bias torch.Size([512])
contextembed1.model.2.weight torch.Size([512, 512])
contextembed1.model.2.bias torch.Size([512])
contextembed2.model.0.weight torch.Size([256, 10])
contextembed2.model.0.bias torch.Size([256])
contextembed2.model.2.weight torch.Size([256, 256])
contextembed2.model.2.bias torch.Size([256])

# 上采样零层,如果不用加上上下文信息,这层完全没有必要,现在是加上了。
up0.0.weight torch.Size([512, 512, 7, 7])
up0.0.bias torch.Size([512])
up0.1.weight torch.Size([512])
up0.1.bias torch.Size([512])
up1.model.0.weight torch.Size([1024, 256, 2, 2])
up1.model.0.bias torch.Size([256])

# 上采样一层
up1.model.1.conv1.0.weight torch.Size([256, 256, 3, 3])
up1.model.1.conv1.0.bias torch.Size([256])
up1.model.1.conv1.1.weight torch.Size([256])
up1.model.1.conv1.1.bias torch.Size([256])
up1.model.1.conv2.0.weight torch.Size([256, 256, 3, 3])
up1.model.1.conv2.0.bias torch.Size([256])
up1.model.1.conv2.1.weight torch.Size([256])
up1.model.1.conv2.1.bias torch.Size([256])
up1.model.2.conv1.0.weight torch.Size([256, 256, 3, 3])
up1.model.2.conv1.0.bias torch.Size([256])
up1.model.2.conv1.1.weight torch.Size([256])
up1.model.2.conv1.1.bias torch.Size([256])
up1.model.2.conv2.0.weight torch.Size([256, 256, 3, 3])
up1.model.2.conv2.0.bias torch.Size([256])
up1.model.2.conv2.1.weight torch.Size([256])
up1.model.2.conv2.1.bias torch.Size([256])

# 上采样二层
up2.model.0.weight torch.Size([512, 256, 2, 2])
up2.model.0.bias torch.Size([256])
up2.model.1.conv1.0.weight torch.Size([256, 256, 3, 3])
up2.model.1.conv1.0.bias torch.Size([256])
up2.model.1.conv1.1.weight torch.Size([256])
up2.model.1.conv1.1.bias torch.Size([256])
up2.model.1.conv2.0.weight torch.Size([256, 256, 3, 3])
up2.model.1.conv2.0.bias torch.Size([256])
up2.model.1.conv2.1.weight torch.Size([256])
up2.model.1.conv2.1.bias torch.Size([256])
up2.model.2.conv1.0.weight torch.Size([256, 256, 3, 3])
up2.model.2.conv1.0.bias torch.Size([256])
up2.model.2.conv1.1.weight torch.Size([256])
up2.model.2.conv1.1.bias torch.Size([256])
up2.model.2.conv2.0.weight torch.Size([256, 256, 3, 3])
up2.model.2.conv2.0.bias torch.Size([256])
up2.model.2.conv2.1.weight torch.Size([256])
up2.model.2.conv2.1.bias torch.Size([256])

# 最终的输出层,将输出的通道进行调整为3
out.0.weight torch.Size([256, 512, 3, 3])
out.0.bias torch.Size([256])
out.1.weight torch.Size([256])
out.1.bias torch.Size([256])
out.3.weight torch.Size([3, 256, 3, 3])
out.3.bias torch.Size([3])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119

当前网络每一层输出的张量情况

# 输入的图片为[32,3,28,28]=[batch_size,channel,height,width]
# 提取特征,扩充通道数
Layer: ResidualConvBlock
Input shape: torch.Size([32, 3, 28, 28])
Output shape: torch.Size([32, 64, 28, 28])
==============================
# 下采样层一:尺寸减半,通道数不变
Layer: UnetDown
Input shape: torch.Size([32, 64, 28, 28])
Output shape: torch.Size([32, 64, 14, 14])
==============================
# 下采样层二:尺寸减半,通道数翻倍
Layer: UnetDown
Input shape: torch.Size([32, 64, 14, 14])
Output shape: torch.Size([32, 128, 7, 7])
==============================
# 还是对输入的特征图进行下采样,是4*4的方格进行下采样
Layer: Sequential
Input shape: torch.Size([32, 128, 7, 7])
Output shape: torch.Size([32, 128, 1, 1])
==============================

# 下述四层为上下文信息处理层,分别处理上下文类别信息和时间序列信息,分层加入到模型中
# 下述为特征上下文信息,每一个样本都有自己的特征上下文
Layer: EmbedFC
Input shape: torch.Size([32, 5])
Output shape: torch.Size([32, 128])
==============================
# 下述为时间序列上下文,所有样本的时间序列是统一的
Layer: EmbedFC
Input shape: torch.Size([1, 1, 1, 1])
Output shape: torch.Size([1, 128])
==============================
# 下述为经过扩展的样本上下文,用于加到第二个上采样层
Layer: EmbedFC
Input shape: torch.Size([32, 5])
Output shape: torch.Size([32, 64])
==============================
# 下述为经过扩展的时间序列信息,用于加到第二个上采样层
Layer: EmbedFC
Input shape: torch.Size([1, 1, 1, 1])
Output shape: torch.Size([1, 64])
==============================

# 上采样层零:扩展维度,对应两个下采样层下的第一个卷积层
Layer: Sequential
Input shape: torch.Size([32, 128, 1, 1])
Output shape: torch.Size([32, 128, 7, 7])
==============================
# 上采样层一
Layer: UnetUp
Input shape: torch.Size([32, 128, 7, 7])
Output shape: torch.Size([32, 64, 14, 14])
==============================
# 上采样层二
Layer: UnetUp
Input shape: torch.Size([32, 64, 14, 14])
Output shape: torch.Size([32, 64, 28, 28])
==============================
# 输出调整层,将输出的信道调整为原始图层
Layer: Sequential
Input shape: torch.Size([32, 128, 28, 28])
Output shape: torch.Size([32, 3, 28, 28])
==============================
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64

网络各层的连接方式

  • 这里最好对照着图片看,会更加清晰,知道他这个网络模型的各个层级之间如何记性沟通。
  • 整体来说,下采样比较简单,上采样比较复杂,因为涉及到添加对应下采样层的输出还有上下文信息、时间序列信息等,所以需要好好看看。
  • 不过可以学到,如何添加额外信息的
 def forward(self, x, t, c=None):
        """
        x : (batch, n_feat, h, w) : input image
        t : (batch, n_cfeat)      : time step
        c : (batch, n_classes)    : context label
        """
        # x is the input image, c is the context label, t is the timestep, context_mask says which samples to block the context on
        
        '''下采样过程'''
        # 将输入的图片传入初始化卷积层中
        x = self.init_conv(x)
        # 将结果传入下采样层
        down1 = self.down1(x)       #[10, 256, 8, 8]
        down2 = self.down2(down1)   #[10, 256, 4, 4]
        
        # 将特征映射为向量
        hiddenvec = self.to_vec(down2)
        
        '''上采样过程'''
        # mask out context if context_mask == 1
        # 判定是否有上下文信息
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)
            
        # 将上下文信息context information还有timestep转为embedding
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)     # (batch, 2*n_feat, 1,1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        #print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}")

        
        # 上采样过程,分别和对应下采样对应层和对应上下文信息加入到每一个上采样层中
        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1*up1 + temb1, down2)  # add and multiply embeddings
        up3 = self.up2(cemb2*up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

训练方法

  • 这里需要明白训练公式,通过公式推导,书写代码,需要明确如下参数
    • α ‾ \overline{\alpha} α 表示若干个 α t \alpha_t αt的连乘
    • ξ θ \xi_\theta ξθ 表示预测的噪声,另外一个表示实际生成的噪声
      在这里插入图片描述
      下述为定义增加噪声的过程
# helper function: perturbs an image to a specified noise level
def perturb_input(x, t, noise):
    # 前向传播公示
    return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise
  • 1
  • 2
  • 3
  • 4

下述为具体的训练代码

# training without context code

# set into train mode
nn_model.train()

for ep in range(n_epoch):
    print(f'epoch {ep}')
    
    # linearly decay learning rate
    # 定义学习率进行线性衰减
    optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch)
    
    # 加载进度条
    pbar = tqdm(dataloader, mininterval=2 )
    for x, _ in pbar:   # x: images
        optim.zero_grad()
        x = x.to(device)
        
        # perturb data
        # 给当前的图片增加噪声
        noise = torch.randn_like(x) # 随机生成噪声
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device) # 随机生成timestep
        x_pert = perturb_input(x, t, noise) # 增加噪声扰动
        
        # use network to recover noise
        # 使用网络去预测噪声
        pred_noise = nn_model(x_pert, t / timesteps)
        
        # loss is mean squared error between the predicted and true noise
        # 使用MSE计算损失
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()
        
        optim.step()

    # save model periodically
    # 按照周期保存模型
    if ep%4==0 or ep == int(n_epoch-1):
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        torch.save(nn_model.state_dict(), save_dir + f"model_{ep}.pth")
        print('saved model at ' + save_dir + f"model_{ep}.pth")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

DDPM采样方法讲解

  • 在这个基础的扩散模型中,最为重要的是denoise_add_noise方法,该方法主要是先如下功能
    • 生成model预测的噪声,从原来数据中减去模型预测的噪声
    • 添加新的额外的噪声,防止训练崩溃
  • 这里的采样方法完全是按照公式进行展开的,重要的是几个参数的构建方法
    • 下属方法中的a_t是公式中的 α t \sqrt\alpha_t α t,

在这里插入图片描述

# construct DDPM noise schedule
# 构建DDPM的计算模式
# 定义 \beta_t  ,表示从零到一的若干均匀分布的小数,有几个时间步骤,就有几个
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1

# 计算\alpha_t 得值
a_t = 1 - b_t
# 这里是通过取对数,然后再去指数,来避免小数连乘的溢出。
ab_t = torch.cumsum(a_t.log(), dim=0).exp()
# 确保x_0的连续性
ab_t[0] = 1

# helper function; removes the predicted noise (but adds some noise back in to avoid collapse)
# 祛除模型预测的噪声,并且添加一些额外的噪声,避免过拟合
def denoise_add_noise(x, t, pred_noise, z=None):

    # 重参数化,实现对特定复杂分布的采样,z是从高斯分布进行的正常采样
    if z is None:
        z = torch.randn_like(x)
    noise = b_t.sqrt()[t] * z
    
    # 公式的前半项,x是当前timestep的情况,这里完全是按照公式进行推倒的
    mean = ((x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) # 减去预测噪声
            / a_t[t].sqrt())
    
    # 增加额外的噪声,防止过拟合
    return mean + noise
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 上述方法完全是按照对应的公示进行展开的,看过了推导之后,发现对于整个公式的理解更加明确。

  • 下述为整体的采样过程

    • 对于每一张图片,都是多次迭代,并且逐步减去噪声
# sample using standard algorithm
@torch.no_grad()
def sample_ddpm(n_sample, save_rate=20):
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)  

    # array to keep track of generated steps for plotting
    intermediate = [] 
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0

        eps = nn_model(samples, t)    # predict noise e_(x_t,t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate ==0 or i==timesteps or i<8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

Context上下文信息添加

  • 关于上下文的添加,在之前的模型定义中ContextUNet是说明了上下文添加具体网络结构,这里就专门讲讲如何在采样过程中,增加对应的上下文信息
    • 就是在之前定义model的forward参数中增加了一个参数c
# sample with context using standard algorithm
@torch.no_grad()
def sample_ddpm_context(n_sample, context, save_rate=20):
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)  

    # array to keep track of generated steps for plotting
    intermediate = [] 
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0
        
        # 和之前一样,就是增加了对应的上下文信息
        eps = nn_model(samples, t, c=context)    # predict noise e_(x_t,t, ctx)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate==0 or i==timesteps or i<8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

DDIM的方法详解

  • DDIM和DDPM二者在前向传播的过程中,是完全相同的,所以他们的模型定义是相同的,完全可以共用的。
  • 但是他们的采样过程是不同,DDIM能够实现跨步采样,速度更快,他是基于任意分布假设,并不是基于马卡洛夫链,所以不用逐步推理。具体算法描述如下

在这里插入图片描述
具体代码如下,下述要结合对应的采样公式,来实现对应的代码

# construct DDPM noise schedule
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
ab_t = torch.cumsum(a_t.log(), dim=0).exp()    
ab_t[0] = 1

# 下述为根据采样公式写出的采样函数
# t是当前的状态数量
# t-prev是根据当前状态t,需要预测prev向前的内容
def denoise_ddim(x, t, t_prev, pred_noise):
    ab = ab_t[t]
    ab_prev = ab_t[t_prev]
    
    x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
    dir_xt = (1 - ab_prev).sqrt() * pred_noise

    return x0_pred + dir_xt

# 具体调用采样过程
# sample quickly using DDIM
@torch.no_grad()
def sample_ddim(n_sample, n=20):
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)  

    # array to keep track of generated steps for plotting
    intermediate = [] 
    step_size = timesteps // n
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t)    # predict noise e_(x_t,t)
        samples = denoise_ddim(samples, i, i - step_size, eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

总结

  • 之前的学习方式有点问题,在扩散模型这里就卡了差不多一周,看公式推导,看相关的代码,学习相关的数学推理,还没有将当前模块嵌入到对应的模型进行测试,效率被大大降低了,所以对于DDIM的学习就简单很多。

参考

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/369813
推荐阅读
相关标签
  

闽ICP备14008679号