mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
feature: Euler scheduler (#138)
This commit is contained in:
parent
ff5ec74e05
commit
8423c5efa7
|
@ -11,7 +11,6 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import Schedul
|
|||
|
||||
T = TypeVar("T", bound="fl.Module")
|
||||
|
||||
|
||||
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
|
||||
|
||||
|
||||
|
@ -91,6 +90,8 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
|
||||
|
||||
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
|
||||
# scale latents for schedulers that need it
|
||||
latents = self.scheduler.scale_model_input(latents, step=step)
|
||||
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
||||
|
||||
# classifier-free guidance
|
||||
|
|
|
@ -2,10 +2,6 @@ from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
|||
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
|
||||
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
|
||||
|
||||
__all__ = [
|
||||
"Scheduler",
|
||||
"DPMSolver",
|
||||
"DDPM",
|
||||
"DDIM",
|
||||
]
|
||||
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
|
||||
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor, Generator
|
||||
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
|
||||
|
@ -34,7 +34,7 @@ class DDIM(Scheduler):
|
|||
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
|
||||
return timesteps.flip(0)
|
||||
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
timestep, previous_timestep = (
|
||||
self.timesteps[step],
|
||||
(
|
||||
|
@ -43,13 +43,10 @@ class DDIM(Scheduler):
|
|||
else tensor(data=[0], device=self.device, dtype=self.dtype)
|
||||
),
|
||||
)
|
||||
current_scale_factor, previous_scale_factor = (
|
||||
self.cumulative_scale_factors[timestep],
|
||||
(
|
||||
self.cumulative_scale_factors[previous_timestep]
|
||||
if previous_timestep > 0
|
||||
else self.cumulative_scale_factors[0]
|
||||
),
|
||||
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
|
||||
self.cumulative_scale_factors[previous_timestep]
|
||||
if previous_timestep > 0
|
||||
else self.cumulative_scale_factors[0]
|
||||
)
|
||||
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
|
||||
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
import numpy as np
|
||||
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype, Generator
|
||||
from collections import deque
|
||||
|
||||
|
||||
class DPMSolver(Scheduler):
|
||||
|
@ -90,12 +88,7 @@ class DPMSolver(Scheduler):
|
|||
)
|
||||
return denoised_x
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
noise: Tensor,
|
||||
step: int,
|
||||
) -> Tensor:
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
"""
|
||||
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
|
||||
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
from torch import Tensor, device as Device, dtype as Dtype, float32, tensor, Generator
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EulerScheduler(Scheduler):
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
if noise_schedule != NoiseSchedule.QUADRATIC:
|
||||
raise NotImplementedError
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.sigmas = self._generate_sigmas()
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self) -> Tensor:
|
||||
return self.sigmas.max()
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
# We need to use numpy here because:
|
||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
# torch.linspace(0,999,31)[15] is 499.5
|
||||
# ...and we want the same result as the original codebase.
|
||||
timesteps = torch.tensor(
|
||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
|
||||
).flip(0)
|
||||
return timesteps
|
||||
|
||||
def _generate_sigmas(self) -> Tensor:
|
||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
|
||||
sigmas = torch.cat([sigmas, tensor([0.0])])
|
||||
return sigmas.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
sigma = self.sigmas[step]
|
||||
return x / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
noise: Tensor,
|
||||
step: int,
|
||||
generator: Generator | None = None,
|
||||
s_churn: float = 0.0,
|
||||
s_tmin: float = 0.0,
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
) -> Tensor:
|
||||
sigma = self.sigmas[step]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
|
||||
|
||||
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
|
||||
eps = alt_noise * s_noise
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
predicted_x = x - sigma_hat * noise
|
||||
|
||||
# 1st order Euler
|
||||
derivative = (x - predicted_x) / sigma_hat
|
||||
dt = self.sigmas[step + 1] - sigma_hat
|
||||
denoised_x = x + derivative * dt
|
||||
|
||||
return denoised_x
|
|
@ -1,9 +1,8 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log, Generator
|
||||
from typing import TypeVar
|
||||
|
||||
from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
|
||||
|
||||
T = TypeVar("T", bound="Scheduler")
|
||||
|
||||
|
||||
|
@ -50,7 +49,7 @@ class Scheduler(ABC):
|
|||
self.timesteps = self._generate_timesteps()
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
"""
|
||||
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
|
||||
|
||||
|
@ -71,6 +70,12 @@ class Scheduler(ABC):
|
|||
def steps(self) -> list[int]:
|
||||
return list(range(self.num_inference_steps))
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
"""
|
||||
For compatibility with schedulers that need to scale the input according to the current timestep.
|
||||
"""
|
||||
return x
|
||||
|
||||
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
|
||||
return (
|
||||
linspace(
|
||||
|
|
|
@ -2,10 +2,10 @@ from typing import cast
|
|||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
from torch import Tensor, allclose, device as Device, equal, randn
|
||||
from torch import Tensor, allclose, device as Device, equal, randn, isclose
|
||||
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler
|
||||
|
||||
|
||||
def test_ddpm_diffusers():
|
||||
|
@ -63,6 +63,34 @@ def test_ddim_diffusers():
|
|||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_euler_diffusers():
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
manual_seed(0)
|
||||
diffusers_scheduler = EulerDiscreteScheduler(
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=0.00085,
|
||||
num_train_timesteps=1000,
|
||||
steps_offset=1,
|
||||
timestep_spacing="linspace",
|
||||
use_karras_sigmas=False,
|
||||
)
|
||||
diffusers_scheduler.set_timesteps(30)
|
||||
refiners_scheduler = EulerScheduler(num_inference_steps=30)
|
||||
|
||||
sample = randn(1, 4, 32, 32)
|
||||
noise = randn(1, 4, 32, 32)
|
||||
|
||||
assert isclose(diffusers_scheduler.init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"
|
||||
|
||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
|
||||
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
|
||||
|
||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_scheduler_remove_noise():
|
||||
from diffusers import DDIMScheduler # type: ignore
|
||||
|
||||
|
@ -84,7 +112,9 @@ def test_scheduler_remove_noise():
|
|||
noise = randn(1, 4, 32, 32)
|
||||
|
||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
|
||||
diffusers_output = cast(
|
||||
Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample
|
||||
) # type: ignore
|
||||
refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)
|
||||
|
||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
|
Loading…
Reference in a new issue