赞
踩
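EDM (Elucidating the Design Space of Diffusion-Based Generative Models): EDM 指 Karras 等人在 2022 年提出的扩散模型设计框架,EDMSampler 实现的是其中的随机采样方法:每一步先按 `s_churn`、`s_tmin`、`s_tmax`、`s_noise` 等参数向样本注入少量新噪声(把噪声水平从 `sigma` 提高到 `sigma_hat`),再用欧拉法对扩散过程对应的常微分方程(ODE)做一步数值积分,把噪声水平降到 `next_sigma`。注意它并不是"带动量的欧拉离散化",代码中也没有任何动量项。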
在 EDMSampler 中,主要原理可以概括为以下几点:
决定能否看懂代码的重点!!!
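`euler_step = self.euler_step(x, d, dt)` 这行代码是在用欧拉法对微分方程做一步数值求解。其中:

- `x`:当前状态变量,即带有噪声水平 `sigma_hat` 的样本。
- `d`:由 `to_d(x, sigma_hat, denoised)` 计算得到,表示 `x` 随噪声水平变化的导数(斜率)dx/dσ。在 Stability AI generative-models(sgm)的实现中,它等于 `(x - denoised) / sigma_hat`,也就是"含噪样本与去噪结果之差"再除以当前噪声水平。
- `dt`:步长,即噪声水平的变化量 `next_sigma - sigma_hat`。采样过程中噪声逐步减小,所以它一般为负数。

欧拉法的一步就是 `x_next = x + dt * d`,这正是 `euler_step` 方法的实现。说得简单点:画一个坐标轴,横轴是噪声水平 σ,纵轴是样本 x。已知当前点的纵坐标 `x` 和该点切线的斜率 `d`,横轴方向移动 `dt`,沿切线走一步,纵轴就变化 `d * dt`,由此得到下一个点的近似值。所以 `d` 本身就是斜率(已知量),这一步求的是下一步的 `x`,而不是斜率。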
class EDMSampler(SingleStepDiffusionSampler):
    """Single-step sampler with optional stochastic "churn" before each Euler step.

    At every step the sample may first receive fresh Gaussian noise that raises
    its noise level from ``sigma`` to ``sigma_hat = sigma * (1 + gamma)``. An
    Euler step then moves it from ``sigma_hat`` to ``next_sigma``, and the result
    is passed through :meth:`possible_correction_step`.

    Args:
        s_churn: Overall churn strength. Each step uses
            ``gamma = min(s_churn / (num_sigmas - 1), sqrt(2) - 1)`` while the
            current sigma lies in ``[s_tmin, s_tmax]``, and ``gamma = 0`` otherwise.
        s_tmin: Lower bound of the sigma range in which churn is applied.
        s_tmax: Upper bound of the sigma range in which churn is applied.
        s_noise: Multiplier applied to the freshly drawn churn noise.
        *args, **kwargs: Forwarded to the parent sampler constructor.
    """

    def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def possible_correction_step(self, euler_step, x, d, dt, next_sigma, model, cond, uc):
        """Refine the Euler prediction; the default applies no correction.

        Subclasses override this hook, for example to re-evaluate the model at
        ``next_sigma``. Without an override, ``euler_step`` is returned unchanged,
        so a bare ``EDMSampler`` performs a plain first-order (Euler) update
        instead of failing with ``AttributeError``.
        """
        return euler_step

    def sampler_step(self, sigma, next_sigma, model, x, cond, uc=None, gamma=0.0, **kwargs):
        """Move ``x`` from noise level ``sigma`` to ``next_sigma``.

        Args:
            sigma: Current noise level (one value per batch element, as built by ``__call__``).
            next_sigma: Target noise level for this step.
            model: Model object handed to :meth:`denoise`.
            x: Current noisy sample.
            cond: Conditioning passed to :meth:`denoise`.
            uc: Unconditional conditioning passed to :meth:`denoise`.
            gamma: Churn factor; ``gamma > 0`` injects extra noise before the step.
            **kwargs: Extra keyword arguments forwarded to :meth:`denoise`.

        Returns:
            The sample produced by :meth:`possible_correction_step`.
        """
        sigma_hat = sigma * (gamma + 1.0)
        if gamma > 0:
            # Churn: add noise with variance sigma_hat**2 - sigma**2 so the
            # sample's noise level rises from sigma to sigma_hat.
            eps = Tensor(np.random.randn(*x.shape), x.dtype) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        denoised = self.denoise(x, model, sigma_hat, cond, uc, **kwargs)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        # Core of the sampler: one explicit Euler step of the ODE (x + dt * d).
        euler_step = self.euler_step(x, d, dt)
        return self.possible_correction_step(euler_step, x, d, dt, next_sigma, model, cond, uc)

    def __call__(self, model, x, cond, uc=None, num_steps=None, **kwargs):
        """Run the full sampling loop and return the final sample.

        ``x`` is cast to float32, prepared by ``prepare_sampling_loop`` and then
        stepped through consecutive pairs ``(sigmas[i], sigmas[i + 1])`` of the
        sigma schedule. Extra ``kwargs`` are forwarded to every step.
        """
        x = ops.cast(x, ms.float32)
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

        # The churn strength is the same for every step, so compute it once.
        # When num_sigmas <= 1 the loop below is empty; skip the division entirely.
        churn_gamma = min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if num_sigmas > 1 else 0.0

        for i in self.get_sigma_gen(num_sigmas):
            gamma = churn_gamma if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
            x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], model, x, cond, uc, gamma, **kwargs)

        return x
-
-
-
-
- #只想搞懂原理的话,下面的依赖可以不看
-
-
class SingleStepDiffusionSampler(BaseDiffusionSampler):
    """Sampler whose steps each make one denoiser call followed by an Euler update."""

    def sampler_step(self, sigma, next_sigma, model, x, cond, uc=None, gamma=0.0, **kwargs):
        """Move ``x`` from noise level ``sigma`` to ``next_sigma`` with one Euler step.

        Args:
            sigma: Current noise level.
            next_sigma: Target noise level for this step.
            model: Model object handed to :meth:`denoise`.
            x: Current noisy sample.
            cond: Conditioning passed to :meth:`denoise`.
            uc: Unconditional conditioning passed to :meth:`denoise`.
            gamma: Churn factor; ``gamma > 0`` injects extra noise before the step.
            **kwargs: Extra keyword arguments forwarded to :meth:`denoise`.

        Returns:
            The updated sample ``x + dt * d``.
        """
        sigma_hat = sigma * (gamma + 1.0)
        if gamma > 0:
            # ``s_noise`` is only set by subclasses such as EDMSampler. Fall back to
            # 1.0 (EDMSampler's default) so a plain instance does not raise
            # AttributeError when churn is requested.
            s_noise = getattr(self, "s_noise", 1.0)
            eps = Tensor(np.random.randn(*x.shape), x.dtype) * s_noise
            # Add noise with variance sigma_hat**2 - sigma**2 to reach level sigma_hat.
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        denoised = self.denoise(x, model, sigma_hat, cond, uc, **kwargs)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        return self.euler_step(x, d, dt)

    def euler_step(self, x, d, dt):
        """Return one explicit Euler update: ``x`` moved along slope ``d`` by step ``dt``."""
        return x + dt * d
-
-
class BaseDiffusionSampler:
    """Shared plumbing for diffusion samplers: sigma schedule, guidance and denoising.

    Args:
        discretization_config: Config passed to ``instantiate_from_config`` to build
            the callable that returns the sigma schedule for a given step count.
        num_steps: Default number of sampling steps, used when
            ``prepare_sampling_loop`` is called without ``num_steps``.
        guider_config: Config for the guider. It is combined with
            ``DEFAULT_GUIDER`` via ``default`` (presumably the fallback when
            ``guider_config`` is None; confirm against ``default``).
        verbose: If true, ``get_sigma_gen`` prints the sampler setup and wraps the
            step loop in a ``tqdm`` progress bar.
    """

    def __init__(
        self,
        discretization_config: Union[Dict, ListConfig, OmegaConf],
        num_steps: Union[int, None] = None,
        guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
        verbose: bool = False,
    ):
        self.num_steps = num_steps
        self.discretization = instantiate_from_config(discretization_config)
        self.guider = instantiate_from_config(
            default(
                guider_config,
                DEFAULT_GUIDER,
            )
        )
        self.verbose = verbose

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Build the sigma schedule and the initial state for the sampling loop.

        Args:
            x: Initial sample; it is scaled by ``sqrt(1 + sigmas[0] ** 2)``.
            cond: Conditioning.
            uc: Unconditional conditioning; passed through ``default(uc, cond)``.
            num_steps: Step count overriding ``self.num_steps`` when not None.

        Returns:
            Tuple ``(x, s_in, sigmas, num_sigmas, cond, uc)``. ``s_in`` is a ones
            vector of length ``x.shape[0]``, used to broadcast a scalar sigma over
            the batch.
        """
        sigmas = self.discretization(self.num_steps if num_steps is None else num_steps)

        uc = default(uc, cond)

        # Rebind instead of using ``*=``: on tensor types where ``*=`` is in-place,
        # that form would also rescale the caller's tensor.
        x = x * Tensor(np.sqrt(1.0 + sigmas[0] ** 2.0), x.dtype)
        num_sigmas = len(sigmas)

        s_in = ops.ones((x.shape[0],), x.dtype)

        return x, s_in, sigmas, num_sigmas, cond, uc

    def denoise(self, x, model, sigma, cond, uc, **kwargs):
        """Return the guided denoised estimate of ``x`` at noise level ``sigma``.

        The guider prepares the (possibly duplicated) inputs. The model's denoiser
        supplies the preconditioning factors ``c_skip, c_out, c_in, c_noise``. The
        network output is cast to float32 before being combined with the input,
        and the guider merges the result.
        """
        noised_input, sigmas, cond = self.guider.prepare_inputs(x, sigma, cond, uc)
        cond = model.openai_input_warpper(cond)
        c_skip, c_out, c_in, c_noise = model.denoiser(sigmas, noised_input.ndim)
        model_output = model.model(noised_input * c_in, c_noise, **cond, **kwargs)
        model_output = model_output.astype(ms.float32)
        # Preconditioned estimate: scaled network output plus skip-connected input.
        denoised = model_output * c_out + noised_input * c_skip
        denoised = self.guider(denoised, sigma)
        return denoised

    def get_sigma_gen(self, num_sigmas):
        """Return an iterable over step indices ``0 .. num_sigmas - 2``.

        When ``self.verbose`` is set, the sampler configuration is printed and the
        range is wrapped in a ``tqdm`` progress bar.
        """
        sigma_generator = range(num_sigmas - 1)
        if self.verbose:
            print("#" * 30, " Sampling setting ", "#" * 30)
            print(f"Sampler: {self.__class__.__name__}")
            print(f"Discretization: {self.discretization.__class__.__name__}")
            print(f"Guider: {self.guider.__class__.__name__}")
            sigma_generator = tqdm(
                sigma_generator,
                total=(num_sigmas - 1),
                desc=f"Sampling with {self.__class__.__name__} for {(num_sigmas - 1)} steps",
            )
        return sigma_generator
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。