feature: Euler scheduler (#138)

Israfel Salazar authored 2024-01-10 11:32:40 +01:00 (committed by GitHub)
parent ff5ec74e05
commit 8423c5efa7
7 changed files with 138 additions and 33 deletions


@@ -11,7 +11,6 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
T = TypeVar("T", bound="fl.Module")
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
@@ -91,6 +90,8 @@ class LatentDiffusionModel(fl.Module, ABC):
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
# scale latents for schedulers that need it
latents = self.scheduler.scale_model_input(latents, step=step)
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance
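Context for this hook (not part of the diff itself): EulerScheduler keeps latents at a variance of roughly sigma^2 + 1, so the UNet input has to be divided by sqrt(sigma^2 + 1) at every step, while DDIM, DDPM and DPMSolver work on unit-variance latents and simply inherit the identity implementation from the base Scheduler. A full usage sketch of this call pattern appears after the new EulerScheduler file below.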


@@ -2,10 +2,6 @@ from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
__all__ = [
"Scheduler",
"DPMSolver",
"DDPM",
"DDIM",
]
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]


@@ -1,4 +1,4 @@
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor, Generator
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
@@ -34,7 +34,7 @@ class DDIM(Scheduler):
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
(
@@ -43,13 +43,10 @@ class DDIM(Scheduler):
else tensor(data=[0], device=self.device, dtype=self.dtype)
),
)
current_scale_factor, previous_scale_factor = (
self.cumulative_scale_factors[timestep],
(
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0
else self.cumulative_scale_factors[0]
),
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
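Two side notes on this file (mine, not from the commit): the new `generator: Generator | None = None` parameter only keeps the `__call__` signatures uniform across schedulers; DDIM is deterministic and never reads it, whereas EulerScheduler below uses it to seed its churn noise. The last two context lines are the usual deterministic DDIM update; writing s_t for cumulative_scale_factors[t]:

    \hat{x}_0 = \frac{x_t - \sqrt{1 - s_t^2}\,\epsilon_\theta}{s_t},
    \qquad
    x_{t'} = s_{t'}\,\hat{x}_0 + \sqrt{1 - s_{t'}^2}\,\epsilon_\theta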


@@ -1,9 +1,7 @@
from collections import deque
import numpy as np
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
import numpy as np
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype, Generator
from collections import deque
class DPMSolver(Scheduler):
@@ -90,12 +88,7 @@ class DPMSolver(Scheduler):
)
return denoised_x
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.


@@ -0,0 +1,83 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from torch import Tensor, device as Device, dtype as Dtype, float32, tensor, Generator
import torch
import numpy as np
class EulerScheduler(Scheduler):
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
if noise_schedule != NoiseSchedule.QUADRATIC:
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)
self.sigmas = self._generate_sigmas()
@property
def init_noise_sigma(self) -> Tensor:
return self.sigmas.max()
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
timesteps = torch.tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
).flip(0)
return timesteps
def _generate_sigmas(self) -> Tensor:
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
sigmas = torch.cat([sigmas, tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5)
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
generator: Generator | None = None,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
) -> Tensor:
sigma = self.sigmas[step]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
eps = alt_noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
predicted_x = x - sigma_hat * noise
# 1st order Euler
derivative = (x - predicted_x) / sigma_hat
dt = self.sigmas[step + 1] - sigma_hat
denoised_x = x + derivative * dt
return denoised_x
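A reading of the new scheduler (my summary, not part of the commit): assuming noise_std = sqrt(1 - cumulative_scale_factors^2) as in the base class, `_generate_sigmas` yields sigma_t = sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t), interpolated at the inference timesteps with a trailing zero, and `__call__` implements the Karras-style churned Euler step:

    \hat\sigma_i = \sigma_i (1 + \gamma), \qquad
    x \leftarrow x + s_{\text{noise}}\,\epsilon\,\sqrt{\hat\sigma_i^2 - \sigma_i^2} \quad (\gamma > 0)

    \hat{x}_0 = x - \hat\sigma_i\,\epsilon_\theta, \qquad
    d_i = \frac{x - \hat{x}_0}{\hat\sigma_i} = \epsilon_\theta, \qquad
    x_{i+1} = x + d_i\,(\sigma_{i+1} - \hat\sigma_i)

With the default s_churn = 0, gamma is always 0 and sigma_hat equals sigma_i, so this reduces to the plain first-order step x_{i+1} = x + epsilon_theta (sigma_{i+1} - sigma_i).

A minimal, self-contained usage sketch (illustrative only; `predict_noise` stands in for the UNet, everything else follows the API added in this commit):

    import torch
    from refiners.foundationals.latent_diffusion.schedulers import EulerScheduler

    def predict_noise(latents: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        # placeholder for the denoiser's epsilon prediction
        return torch.randn_like(latents)

    scheduler = EulerScheduler(num_inference_steps=30)
    generator = torch.Generator().manual_seed(42)  # seeds the churn noise when s_churn > 0
    x = torch.randn(1, 4, 32, 32, generator=generator) * scheduler.init_noise_sigma

    for step in scheduler.steps:
        model_input = scheduler.scale_model_input(x, step=step)  # divide by sqrt(sigma^2 + 1)
        noise = predict_noise(model_input, scheduler.timesteps[step])
        x = scheduler(x, noise=noise, step=step, generator=generator)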


@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from enum import Enum
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log, Generator
from typing import TypeVar
from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
T = TypeVar("T", bound="Scheduler")
@@ -50,7 +49,7 @@ class Scheduler(ABC):
self.timesteps = self._generate_timesteps()
@abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
@@ -71,6 +70,12 @@
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""
For compatibility with schedulers that need to scale the input according to the current timestep.
"""
return x
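Because this default returns `x` unchanged, the existing schedulers keep their behavior; only EulerScheduler overrides it (see the new file above).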
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(


@@ -2,10 +2,10 @@ from typing import cast
from warnings import warn
import pytest
from torch import Tensor, allclose, device as Device, equal, randn
from torch import Tensor, allclose, device as Device, equal, randn, isclose
from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler
def test_ddpm_diffusers():
@@ -63,6 +63,34 @@ def test_ddim_diffusers():
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
def test_euler_diffusers():
from diffusers import EulerDiscreteScheduler
manual_seed(0)
diffusers_scheduler = EulerDiscreteScheduler(
beta_end=0.012,
beta_schedule="scaled_linear",
beta_start=0.00085,
num_train_timesteps=1000,
steps_offset=1,
timestep_spacing="linspace",
use_karras_sigmas=False,
)
diffusers_scheduler.set_timesteps(30)
refiners_scheduler = EulerScheduler(num_inference_steps=30)
sample = randn(1, 4, 32, 32)
noise = randn(1, 4, 32, 32)
assert isclose(diffusers_scheduler.init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"
for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
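For reference (my reading, not asserted by the test): the diffusers configuration above matches the refiners defaults, so the comparison written out explicitly would look like this (hypothetical; the keyword arguments below are already the defaults):

    from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule

    refiners_scheduler = EulerScheduler(
        num_inference_steps=30,
        num_train_timesteps=1000,
        initial_diffusion_rate=0.00085,  # beta_start
        final_diffusion_rate=0.012,      # beta_end
        noise_schedule=NoiseSchedule.QUADRATIC,  # diffusers' "scaled_linear"
    )

Likewise, `timestep_spacing="linspace"` on the diffusers side corresponds to the `np.linspace` call in `_generate_timesteps`.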
def test_scheduler_remove_noise():
from diffusers import DDIMScheduler # type: ignore
@@ -84,7 +112,9 @@ def test_scheduler_remove_noise():
noise = randn(1, 4, 32, 32)
for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
diffusers_output = cast(
Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample
) # type: ignore
refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"