diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py index 01417e7..a7490dc 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddpm.py @@ -1,4 +1,4 @@ -from torch import Generator, Tensor, arange, device as Device, randn, tensor +from torch import Tensor, arange, device as Device from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler @@ -30,54 +30,5 @@ class DDPM(Scheduler): timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio return timesteps.flip(0) - def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: - """ - Generate the next step in the diffusion process. - - This method adjusts the input data using added noise and an estimate of the denoised data, based on the current - step in the diffusion process. This adjusted data forms the next step in the diffusion process. - - 1. It uses current and previous timesteps to calculate the current factor dictating the contribution of original - data and noise to the new step. - 2. An estimate of the denoised data (`estimated_denoised_data`) is generated. - 3. It calculates coefficients for the estimated denoised data and current data (`original_data_coeff` and - `current_data_coeff`) that balance their contribution to the denoised data for the next step. - 4. It calculates the denoised data for the next step (`denoised_x`), which is a combination of the estimated - denoised data and current data, adjusted by their respective coefficients. - 5. Noise is then added to `denoised_x`. The magnitude of noise is controlled by a calculated variance based on - the cumulative scaling factor and the current factor. - - The output is the new data step for the next stage in the diffusion process. - """ - timestep, previous_timestep = ( - self.timesteps[step], - ( - self.timesteps[step + 1] - if step < len(self.timesteps) - 1 - else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device) - ), - ) - current_cumulative_factor, previous_cumulative_scale_factor = ( - (self.scale_factors.cumprod(0))[timestep], - ( - (self.scale_factors.cumprod(0))[previous_timestep] - if step < len(self.timesteps) - 1 - else tensor(1, device=self.device) - ), - ) - current_factor = current_cumulative_factor / previous_cumulative_scale_factor - estimated_denoised_data = (x - (1 - current_cumulative_factor) ** 0.5 * noise) / current_cumulative_factor**0.5 - estimated_denoised_data = estimated_denoised_data.clamp(-1, 1) - original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / ( - 1 - current_cumulative_factor - ) - current_data_coeff = ( - current_factor**0.5 * (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor) - ) - denoised_x = original_data_coeff * estimated_denoised_data + current_data_coeff * x - if step < len(self.timesteps) - 1: - variance = (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor) * (1 - current_factor) - denoised_x = denoised_x + (variance.clamp(min=1e-20) ** 0.5) * randn( - x.shape, device=x.device, dtype=x.dtype, generator=generator - ) - return denoised_x + def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + raise NotImplementedError diff --git a/tests/foundationals/latent_diffusion/test_schedulers.py b/tests/foundationals/latent_diffusion/test_schedulers.py index 186d88b..fc94594 100644 --- a/tests/foundationals/latent_diffusion/test_schedulers.py +++ b/tests/foundationals/latent_diffusion/test_schedulers.py @@ -2,10 +2,20 @@ from typing import cast from warnings import warn import pytest -from torch import Tensor, allclose, device as Device, randn +from torch import Tensor, allclose, device as Device, equal, randn from refiners.fluxion import manual_seed -from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver +from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver + + +def test_ddpm_diffusers(): + from diffusers import DDPMScheduler # type: ignore + + diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012) + diffusers_scheduler.set_timesteps(1000) + refiners_scheduler = DDPM(num_inference_steps=1000) + + assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps) def test_dpm_solver_diffusers():