mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
scheduler: add remove noise
aka original sample prediction (or predict x0) E.g. useful for methods like self-attention guidance (see equation (2) in https://arxiv.org/pdf/2210.00939.pdf)
This commit is contained in:
parent
665bcdc95c
commit
7d2abf6fbc
|
@ -78,11 +78,20 @@ class Scheduler(ABC):
|
||||||
step: int,
|
step: int,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
timestep = self.timesteps[step]
|
timestep = self.timesteps[step]
|
||||||
cumulative_scale_factors = self.cumulative_scale_factors[timestep].unsqueeze(-1).unsqueeze(-1)
|
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
||||||
noise_stds = self.noise_std[timestep].unsqueeze(-1).unsqueeze(-1)
|
noise_stds = self.noise_std[timestep]
|
||||||
noised_x = cumulative_scale_factors * x + noise_stds * noise
|
noised_x = cumulative_scale_factors * x + noise_stds * noise
|
||||||
return noised_x
|
return noised_x
|
||||||
|
|
||||||
|
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
|
timestep = self.timesteps[step]
|
||||||
|
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
||||||
|
noise_stds = self.noise_std[timestep]
|
||||||
|
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like
|
||||||
|
# in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
|
||||||
|
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
|
||||||
|
return denoised_x
|
||||||
|
|
||||||
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
|
||||||
if device is not None:
|
if device is not None:
|
||||||
self.device = Device(device)
|
self.device = Device(device)
|
||||||
|
|
|
@ -71,6 +71,31 @@ def test_ddim_solver_diffusers():
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_remove_noise():
|
||||||
|
from diffusers import DDIMScheduler # type: ignore
|
||||||
|
|
||||||
|
diffusers_scheduler = DDIMScheduler(
|
||||||
|
beta_end=0.012,
|
||||||
|
beta_schedule="scaled_linear",
|
||||||
|
beta_start=0.00085,
|
||||||
|
num_train_timesteps=1000,
|
||||||
|
set_alpha_to_one=False,
|
||||||
|
steps_offset=1,
|
||||||
|
clip_sample=False,
|
||||||
|
)
|
||||||
|
diffusers_scheduler.set_timesteps(30)
|
||||||
|
refiners_scheduler = DDIM(num_inference_steps=30)
|
||||||
|
|
||||||
|
sample = randn(1, 4, 32, 32)
|
||||||
|
noise = randn(1, 4, 32, 32)
|
||||||
|
|
||||||
|
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||||
|
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
|
||||||
|
refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)
|
||||||
|
|
||||||
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
def test_scheduler_device(test_device: Device):
|
def test_scheduler_device(test_device: Device):
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
|
|
Loading…
Reference in a new issue