赞
踩
代码解析:正向扩散过程和加噪演示
引言
这段代码实现了一个正向扩散过程和加噪演示的功能。通过生成一个特定形状的数据集,并在每个时间步长上应用正向扩散过程和加噪过程,最终展示了数据点在空间中的演变过程。
数据集生成
通过 make_swiss_roll 函数生成一个类似瑞士卷的数据集,数据集具有特定的形状和噪声。在这个示例中,数据集被缩放和裁剪,以便更好地展示正向扩散和加噪的效果。
超参数设定
设定了一系列超参数,包括时间步数 num_steps 和用于控制正向扩散过程的 alphas 和 betas。这些超参数决定了正向扩散过程中的权重变化,并影响数据点在空间中的演变轨迹。
正向扩散过程
定义了一个函数 q_x,用于执行正向扩散过程。该函数接受初始数据点和时间步长作为输入,并根据预先设定的超参数计算出新的数据点。在每个时间步长上,根据权重 alphas 和 betas,将初始数据点与噪声相结合,生成新的数据点。
加噪演示
通过循环迭代,每隔一定的时间步长,在图表中展示了数据点的演变过程。在每个演示步骤中,通过调用 q_x 函数生成新的数据点,并在图表中以散点图的形式展示。这样可以清晰地观察到数据点在空间中的变化,从而更好地理解加噪的效果。
结论
这段代码展示了如何使用正向扩散过程和加噪过程来生成和演示数据集的变化。通过调整超参数和观察结果,可以更好地理解数据的分布和特征,为后续的数据分析和建模工作提供参考。
import torch import torch.nn as nn import numpy as np import matplotlib.pyplot as plt from sklearn.datasets import make_swiss_roll # 导入 make_swiss_roll 函数 # 构建我们需要的数据集 s_curve, _ = make_swiss_roll(10**4, noise=0.1) s_curve = s_curve[:, [0, 2]] / 10.0 dataset = torch.Tensor(s_curve).float() # 确定时间步数 num_steps = 100 # 确定alpha、beta超参数的值 betas = torch.linspace(-6, 6, num_steps) betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5 alphas = 1 - betas alphas_prod = torch.cumprod(alphas, 0) alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0) alphas_bar_sqrt = torch.sqrt(alphas_prod) one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod) # 正向扩散过程——根据x_0和noise计算出任意时刻的x_t值 def q_x(x_0, t): noise = torch.randn_like(x_0) alphas_t = alphas_bar_sqrt[t] alphas_1_m_t = one_minus_alphas_bar_sqrt[t] return (alphas_t * x_0 + alphas_1_m_t * noise) # 演示加噪过程,每20步展示一次结果 num_shows = 20 fig, axs = plt.subplots(2, 10, figsize=(28, 3)) for i in range(num_shows): j = i // 10 k = i % 10 q_i = q_x(dataset, torch.tensor([i * num_steps // num_shows])) axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white') axs[j, k].set_axis_off() axs[j, k].set_title(f'$q(\\mathbf{{x}}_{{{i * num_steps // num_shows}}})$') plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。