deprecate DDPM step which is unused for now

This commit is contained in:
Cédric Deltheil 2023-12-13 11:43:11 +01:00 committed by Cédric Deltheil
parent a7551e0392
commit 82a2aa1ec4
2 changed files with 15 additions and 54 deletions

View file

@ -1,4 +1,4 @@
from torch import Generator, Tensor, arange, device as Device, randn, tensor
from torch import Tensor, arange, device as Device
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
@ -30,54 +30,5 @@ class DDPM(Scheduler):
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Generate the next step in the diffusion process.
This method adjusts the input data using added noise and an estimate of the denoised data, based on the current
step in the diffusion process. This adjusted data forms the next step in the diffusion process.
1. It uses current and previous timesteps to calculate the current factor dictating the contribution of original
data and noise to the new step.
2. An estimate of the denoised data (`estimated_denoised_data`) is generated.
3. It calculates coefficients for the estimated denoised data and current data (`original_data_coeff` and
`current_data_coeff`) that balance their contribution to the denoised data for the next step.
4. It calculates the denoised data for the next step (`denoised_x`), which is a combination of the estimated
denoised data and current data, adjusted by their respective coefficients.
5. Noise is then added to `denoised_x`. The magnitude of noise is controlled by a calculated variance based on
the cumulative scaling factor and the current factor.
The output is the new data step for the next stage in the diffusion process.
"""
timestep, previous_timestep = (
self.timesteps[step],
(
self.timesteps[step + 1]
if step < len(self.timesteps) - 1
else tensor(-(self.num_train_timesteps // self.num_inference_steps), device=self.device)
),
)
current_cumulative_factor, previous_cumulative_scale_factor = (
(self.scale_factors.cumprod(0))[timestep],
(
(self.scale_factors.cumprod(0))[previous_timestep]
if step < len(self.timesteps) - 1
else tensor(1, device=self.device)
),
)
current_factor = current_cumulative_factor / previous_cumulative_scale_factor
estimated_denoised_data = (x - (1 - current_cumulative_factor) ** 0.5 * noise) / current_cumulative_factor**0.5
estimated_denoised_data = estimated_denoised_data.clamp(-1, 1)
original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / (
1 - current_cumulative_factor
)
current_data_coeff = (
current_factor**0.5 * (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor)
)
denoised_x = original_data_coeff * estimated_denoised_data + current_data_coeff * x
if step < len(self.timesteps) - 1:
variance = (1 - previous_cumulative_scale_factor) / (1 - current_cumulative_factor) * (1 - current_factor)
denoised_x = denoised_x + (variance.clamp(min=1e-20) ** 0.5) * randn(
x.shape, device=x.device, dtype=x.dtype, generator=generator
)
return denoised_x
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
raise NotImplementedError

View file

@ -2,10 +2,20 @@ from typing import cast
from warnings import warn
import pytest
from torch import Tensor, allclose, device as Device, randn
from torch import Tensor, allclose, device as Device, equal, randn
from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver
def test_ddpm_diffusers():
from diffusers import DDPMScheduler # type: ignore
diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
diffusers_scheduler.set_timesteps(1000)
refiners_scheduler = DDPM(num_inference_steps=1000)
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
def test_dpm_solver_diffusers():