scheduler: add remove noise

aka original sample prediction (or predict x0)

E.g. useful for methods like self-attention guidance (see equation (2)
in https://arxiv.org/pdf/2210.00939.pdf)
This commit is contained in:
Cédric Deltheil 2023-10-05 16:44:38 +02:00 committed by Cédric Deltheil
parent 665bcdc95c
commit 7d2abf6fbc
2 changed files with 36 additions and 2 deletions

View file

@ -78,11 +78,20 @@ class Scheduler(ABC):
step: int,
) -> Tensor:
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep].unsqueeze(-1).unsqueeze(-1)
noise_stds = self.noise_std[timestep].unsqueeze(-1).unsqueeze(-1)
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like
# in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
return denoised_x
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
if device is not None:
self.device = Device(device)

View file

@ -71,6 +71,31 @@ def test_ddim_solver_diffusers():
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
def test_scheduler_remove_noise():
from diffusers import DDIMScheduler # type: ignore
diffusers_scheduler = DDIMScheduler(
beta_end=0.012,
beta_schedule="scaled_linear",
beta_start=0.00085,
num_train_timesteps=1000,
set_alpha_to_one=False,
steps_offset=1,
clip_sample=False,
)
diffusers_scheduler.set_timesteps(30)
refiners_scheduler = DDIM(num_inference_steps=30)
sample = randn(1, 4, 32, 32)
noise = randn(1, 4, 32, 32)
for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
def test_scheduler_device(test_device: Device):
if test_device.type == "cpu":
warn("not running on CPU, skipping")