Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)

feature: Euler scheduler (#138)

Commit: 8423c5efa7
Parent: ff5ec74e05
@@ -11,7 +11,6 @@ from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
 T = TypeVar("T", bound="fl.Module")
-
 TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")

@@ -91,6 +90,8 @@ class LatentDiffusionModel(fl.Module, ABC):
         self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

         latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
+        # scale latents for schedulers that need it
+        latents = self.scheduler.scale_model_input(latents, step=step)
         unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)

         # classifier-free guidance
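The two added lines give the scheduler a hook into the guidance step. A minimal sketch of the resulting control flow, where `unet` and `guidance_scale` are illustrative stand-ins rather than names taken from this diff:

import torch

def cfg_step(unet, scheduler, x, step, guidance_scale=7.5):
    # duplicate the latents so the conditional and unconditional branches share one UNet call
    latents = torch.cat(tensors=(x, x))
    # no-op for schedulers without a sigma parameterization, a rescaling for EulerScheduler
    latents = scheduler.scale_model_input(latents, step=step)
    unconditional, conditional = unet(latents).chunk(2)
    return unconditional + guidance_scale * (conditional - unconditional)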
@@ -2,10 +2,6 @@ from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
 from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
 from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
 from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler

-__all__ = [
-    "Scheduler",
-    "DPMSolver",
-    "DDPM",
-    "DDIM",
-]
+__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]
@@ -1,4 +1,4 @@
-from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
+from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor, Generator

 from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
@@ -34,7 +34,7 @@ class DDIM(Scheduler):
         timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
         return timesteps.flip(0)

-    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
+    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         timestep, previous_timestep = (
             self.timesteps[step],
             (
@@ -43,13 +43,10 @@ class DDIM(Scheduler):
                 else tensor(data=[0], device=self.device, dtype=self.dtype)
             ),
         )
-        current_scale_factor, previous_scale_factor = (
-            self.cumulative_scale_factors[timestep],
-            (
-                self.cumulative_scale_factors[previous_timestep]
-                if previous_timestep > 0
-                else self.cumulative_scale_factors[0]
-            ),
-        )
+        current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
+            self.cumulative_scale_factors[previous_timestep]
+            if previous_timestep > 0
+            else self.cumulative_scale_factors[0]
+        )
         predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
         denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
@@ -1,9 +1,7 @@
-from collections import deque
-
-import numpy as np
-from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor
-
 from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+import numpy as np
+from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype, Generator
+from collections import deque


 class DPMSolver(Scheduler):
@@ -90,12 +88,7 @@ class DPMSolver(Scheduler):
         )
         return denoised_x

-    def __call__(
-        self,
-        x: Tensor,
-        noise: Tensor,
-        step: int,
-    ) -> Tensor:
+    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         """
         Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
@@ -0,0 +1,83 @@
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
+from torch import Tensor, device as Device, dtype as Dtype, float32, tensor, Generator
+import torch
+import numpy as np
+
+
+class EulerScheduler(Scheduler):
+    def __init__(
+        self,
+        num_inference_steps: int,
+        num_train_timesteps: int = 1_000,
+        initial_diffusion_rate: float = 8.5e-4,
+        final_diffusion_rate: float = 1.2e-2,
+        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
+        device: Device | str = "cpu",
+        dtype: Dtype = float32,
+    ):
+        if noise_schedule != NoiseSchedule.QUADRATIC:
+            raise NotImplementedError
+        super().__init__(
+            num_inference_steps=num_inference_steps,
+            num_train_timesteps=num_train_timesteps,
+            initial_diffusion_rate=initial_diffusion_rate,
+            final_diffusion_rate=final_diffusion_rate,
+            noise_schedule=noise_schedule,
+            device=device,
+            dtype=dtype,
+        )
+        self.sigmas = self._generate_sigmas()
+
+    @property
+    def init_noise_sigma(self) -> Tensor:
+        return self.sigmas.max()
+
+    def _generate_timesteps(self) -> Tensor:
+        # We need to use numpy here because:
+        # numpy.linspace(0,999,31)[15] is 499.49999999999994
+        # torch.linspace(0,999,31)[15] is 499.5
+        # ...and we want the same result as the original codebase.
+        timesteps = torch.tensor(
+            np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
+        ).flip(0)
+        return timesteps
+
+    def _generate_sigmas(self) -> Tensor:
+        sigmas = self.noise_std / self.cumulative_scale_factors
+        sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
+        sigmas = torch.cat([sigmas, tensor([0.0])])
+        return sigmas.to(device=self.device, dtype=self.dtype)
+
+    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
+        sigma = self.sigmas[step]
+        return x / ((sigma**2 + 1) ** 0.5)
+
+    def __call__(
+        self,
+        x: Tensor,
+        noise: Tensor,
+        step: int,
+        generator: Generator | None = None,
+        s_churn: float = 0.0,
+        s_tmin: float = 0.0,
+        s_tmax: float = float("inf"),
+        s_noise: float = 1.0,
+    ) -> Tensor:
+        sigma = self.sigmas[step]
+
+        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
+
+        alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
+        eps = alt_noise * s_noise
+        sigma_hat = sigma * (gamma + 1)
+        if gamma > 0:
+            x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
+
+        predicted_x = x - sigma_hat * noise
+
+        # 1st order Euler
+        derivative = (x - predicted_x) / sigma_hat
+        dt = self.sigmas[step + 1] - sigma_hat
+        denoised_x = x + derivative * dt
+
+        return denoised_x
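A minimal usage sketch for the new scheduler, based only on the API visible in this diff; `predict_noise` is a hypothetical stand-in for a UNet noise prediction and is not part of refiners:

import torch
from refiners.foundationals.latent_diffusion.schedulers import EulerScheduler

def predict_noise(latents: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
    return torch.randn_like(latents)  # placeholder for a real UNet call

scheduler = EulerScheduler(num_inference_steps=30)
x = torch.randn(1, 4, 32, 32) * scheduler.init_noise_sigma  # start from noise at the largest sigma

for step in scheduler.steps:
    model_input = scheduler.scale_model_input(x, step=step)  # divide by sqrt(sigma^2 + 1)
    noise = predict_noise(model_input, scheduler.timesteps[step])
    x = scheduler(x=x, noise=noise, step=step)  # one first-order Euler step toward sigma = 0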
@@ -1,9 +1,8 @@
 from abc import ABC, abstractmethod
 from enum import Enum
+from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log, Generator
 from typing import TypeVar

-from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
-
 T = TypeVar("T", bound="Scheduler")
@@ -50,7 +49,7 @@ class Scheduler(ABC):
         self.timesteps = self._generate_timesteps()

     @abstractmethod
-    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
+    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         """
         Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
@@ -71,6 +70,12 @@ class Scheduler(ABC):
     def steps(self) -> list[int]:
         return list(range(self.num_inference_steps))

+    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
+        """
+        For compatibility with schedulers that need to scale the input according to the current timestep.
+        """
+        return x
+
     def sample_power_distribution(self, power: float = 2, /) -> Tensor:
         return (
             linspace(
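A hedged sketch of what threading `generator` through `__call__` buys: the stochastic churn in EulerScheduler (the only scheduler in this diff that actually draws extra noise) becomes reproducible. The seed value and tensor shapes below are illustrative assumptions:

import torch
from refiners.foundationals.latent_diffusion.schedulers import EulerScheduler

scheduler = EulerScheduler(num_inference_steps=30)
x = torch.randn(1, 4, 32, 32)
noise = torch.randn(1, 4, 32, 32)

# s_churn > 0 makes gamma > 0, so the step injects extra noise drawn from the generator
out_a = scheduler(x=x, noise=noise, step=0, generator=torch.Generator().manual_seed(42), s_churn=1.0)
out_b = scheduler(x=x, noise=noise, step=0, generator=torch.Generator().manual_seed(42), s_churn=1.0)
assert torch.equal(out_a, out_b)  # same seed, same injected noise, identical result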
@@ -2,10 +2,10 @@ from typing import cast
 from warnings import warn

 import pytest
-from torch import Tensor, allclose, device as Device, equal, randn
+from torch import Tensor, allclose, device as Device, equal, randn, isclose

 from refiners.fluxion import manual_seed
-from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver
+from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler


 def test_ddpm_diffusers():
@@ -63,6 +63,34 @@ def test_ddim_diffusers():
         assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


+def test_euler_diffusers():
+    from diffusers import EulerDiscreteScheduler
+
+    manual_seed(0)
+    diffusers_scheduler = EulerDiscreteScheduler(
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        beta_start=0.00085,
+        num_train_timesteps=1000,
+        steps_offset=1,
+        timestep_spacing="linspace",
+        use_karras_sigmas=False,
+    )
+    diffusers_scheduler.set_timesteps(30)
+    refiners_scheduler = EulerScheduler(num_inference_steps=30)
+
+    sample = randn(1, 4, 32, 32)
+    noise = randn(1, 4, 32, 32)
+
+    assert isclose(diffusers_scheduler.init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"
+
+    for step, timestep in enumerate(diffusers_scheduler.timesteps):
+        diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample)  # type: ignore
+        refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
+
+        assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
+
+
 def test_scheduler_remove_noise():
     from diffusers import DDIMScheduler  # type: ignore
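The test relies on both schedulers producing the same noise levels for the same settings. A hedged way to check that alignment directly, assuming refiners' `cumulative_scale_factors` is sqrt(alpha_bar) and `noise_std` is sqrt(1 - alpha_bar) as the DDIM formulas above suggest:

import torch
from diffusers import EulerDiscreteScheduler
from refiners.foundationals.latent_diffusion.schedulers import EulerScheduler

diffusers_scheduler = EulerDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
    num_train_timesteps=1000, timestep_spacing="linspace", use_karras_sigmas=False,
)
diffusers_scheduler.set_timesteps(30)
refiners_scheduler = EulerScheduler(num_inference_steps=30)

# Both should expose 31 sigmas (30 noise levels plus a trailing zero), largest first.
assert torch.allclose(diffusers_scheduler.sigmas.to(torch.float32), refiners_scheduler.sigmas, rtol=0.01)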
@@ -84,7 +112,9 @@ def test_scheduler_remove_noise():
     noise = randn(1, 4, 32, 32)

     for step, timestep in enumerate(diffusers_scheduler.timesteps):
-        diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample)  # type: ignore
+        diffusers_output = cast(
+            Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample
+        )  # type: ignore
         refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)

         assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"