mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
refactor solvers to support different timestep spacings
This commit is contained in:
parent
d14c5bd5f8
commit
ddc1cf8ca7
|
@ -3,6 +3,6 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
||||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
||||||
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||||
|
|
||||||
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule"]
|
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||||
|
|
||||||
|
|
||||||
class DDIM(Solver):
|
class DDIM(Solver):
|
||||||
|
@ -13,6 +13,8 @@ class DDIM(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
num_train_timesteps: int = 1_000,
|
||||||
|
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
|
||||||
|
timesteps_offset: int = 1,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||||
|
@ -25,6 +27,8 @@ class DDIM(Solver):
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps.
|
num_inference_steps: The number of inference steps.
|
||||||
num_train_timesteps: The number of training timesteps.
|
num_train_timesteps: The number of training timesteps.
|
||||||
|
timesteps_spacing: The spacing to use for the timesteps.
|
||||||
|
timesteps_offset: The offset to use for the timesteps.
|
||||||
initial_diffusion_rate: The initial diffusion rate.
|
initial_diffusion_rate: The initial diffusion rate.
|
||||||
final_diffusion_rate: The final diffusion rate.
|
final_diffusion_rate: The final diffusion rate.
|
||||||
noise_schedule: The noise schedule.
|
noise_schedule: The noise schedule.
|
||||||
|
@ -35,6 +39,8 @@ class DDIM(Solver):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
timesteps_spacing=timesteps_spacing,
|
||||||
|
timesteps_offset=timesteps_offset,
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
noise_schedule=noise_schedule,
|
noise_schedule=noise_schedule,
|
||||||
|
@ -43,16 +49,18 @@ class DDIM(Solver):
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
|
||||||
"""
|
|
||||||
Generates decreasing timesteps with 'leading' spacing and offset of 1
|
|
||||||
similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5
|
|
||||||
"""
|
|
||||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
|
||||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
|
|
||||||
return timesteps.flip(0)
|
|
||||||
|
|
||||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
|
"""Apply one step of the backward diffusion process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input tensor to apply the diffusion process to.
|
||||||
|
predicted_noise: The predicted noise tensor for the current step.
|
||||||
|
step: The current step of the diffusion process.
|
||||||
|
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The denoised version of the input data `x`.
|
||||||
|
"""
|
||||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||||
|
|
||||||
timestep, previous_timestep = (
|
timestep, previous_timestep = (
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from torch import Generator, Tensor, arange, device as Device
|
from torch import Generator, Tensor, device as Device
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
||||||
|
|
||||||
|
|
||||||
class DDPM(Solver):
|
class DDPM(Solver):
|
||||||
|
@ -17,6 +17,8 @@ class DDPM(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
num_train_timesteps: int = 1_000,
|
||||||
|
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
|
||||||
|
timesteps_offset: int = 0,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
|
@ -27,6 +29,8 @@ class DDPM(Solver):
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps.
|
num_inference_steps: The number of inference steps.
|
||||||
num_train_timesteps: The number of training timesteps.
|
num_train_timesteps: The number of training timesteps.
|
||||||
|
timesteps_spacing: The spacing to use for the timesteps.
|
||||||
|
timesteps_offset: The offset to use for the timesteps.
|
||||||
initial_diffusion_rate: The initial diffusion rate.
|
initial_diffusion_rate: The initial diffusion rate.
|
||||||
final_diffusion_rate: The final diffusion rate.
|
final_diffusion_rate: The final diffusion rate.
|
||||||
first_inference_step: The first inference step.
|
first_inference_step: The first inference step.
|
||||||
|
@ -35,16 +39,13 @@ class DDPM(Solver):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
timesteps_spacing=timesteps_spacing,
|
||||||
|
timesteps_offset=timesteps_offset,
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
first_inference_step=first_inference_step,
|
first_inference_step=first_inference_step,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
|
||||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
|
||||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
|
|
||||||
return timesteps.flip(0)
|
|
||||||
|
|
||||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||||
|
|
||||||
|
|
||||||
class DPMSolver(Solver):
|
class DPMSolver(Solver):
|
||||||
|
@ -23,6 +22,8 @@ class DPMSolver(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
num_train_timesteps: int = 1_000,
|
||||||
|
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT,
|
||||||
|
timesteps_offset: int = 0,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
last_step_first_order: bool = False,
|
last_step_first_order: bool = False,
|
||||||
|
@ -31,9 +32,26 @@ class DPMSolver(Solver):
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
):
|
):
|
||||||
|
"""Initializes a new DPM solver.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps: The number of inference steps.
|
||||||
|
num_train_timesteps: The number of training timesteps.
|
||||||
|
timesteps_spacing: The spacing to use for the timesteps.
|
||||||
|
timesteps_offset: The offset to use for the timesteps.
|
||||||
|
initial_diffusion_rate: The initial diffusion rate.
|
||||||
|
final_diffusion_rate: The final diffusion rate.
|
||||||
|
last_step_first_order: Use a first-order update for the last step.
|
||||||
|
noise_schedule: The noise schedule.
|
||||||
|
first_inference_step: The first inference step.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
timesteps_spacing=timesteps_spacing,
|
||||||
|
timesteps_offset=timesteps_offset,
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
noise_schedule=noise_schedule,
|
noise_schedule=noise_schedule,
|
||||||
|
@ -44,21 +62,6 @@ class DPMSolver(Solver):
|
||||||
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
||||||
self.last_step_first_order = last_step_first_order
|
self.last_step_first_order = last_step_first_order
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
|
||||||
"""Generate the timesteps used by the solver.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
We need to use numpy here because:
|
|
||||||
|
|
||||||
- numpy.linspace(0,999,31)[15] is 499.49999999999994
|
|
||||||
- torch.linspace(0,999,31)[15] is 499.5
|
|
||||||
|
|
||||||
and we want the same result as the original codebase.
|
|
||||||
"""
|
|
||||||
return tensor(
|
|
||||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
|
|
||||||
).flip(0)
|
|
||||||
|
|
||||||
def rebuild(
|
def rebuild(
|
||||||
self: "DPMSolver",
|
self: "DPMSolver",
|
||||||
num_inference_steps: int | None,
|
num_inference_steps: int | None,
|
||||||
|
@ -148,10 +151,10 @@ class DPMSolver(Solver):
|
||||||
(ODEs).
|
(ODEs).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: The input data.
|
x: The input tensor to apply the diffusion process to.
|
||||||
predicted_noise: The predicted noise.
|
predicted_noise: The predicted noise tensor for the current step.
|
||||||
step: The current step.
|
step: The current step of the diffusion process.
|
||||||
generator: The random number generator.
|
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The denoised version of the input data `x`.
|
The denoised version of the input data `x`.
|
||||||
|
|
|
@ -2,7 +2,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||||
|
|
||||||
|
|
||||||
class Euler(Solver):
|
class Euler(Solver):
|
||||||
|
@ -16,6 +16,8 @@ class Euler(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
num_train_timesteps: int = 1_000,
|
||||||
|
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
|
||||||
|
timesteps_offset: int = 0,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||||
|
@ -28,6 +30,8 @@ class Euler(Solver):
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps.
|
num_inference_steps: The number of inference steps.
|
||||||
num_train_timesteps: The number of training timesteps.
|
num_train_timesteps: The number of training timesteps.
|
||||||
|
timesteps_spacing: The spacing to use for the timesteps.
|
||||||
|
timesteps_offset: The offset to use for the timesteps.
|
||||||
initial_diffusion_rate: The initial diffusion rate.
|
initial_diffusion_rate: The initial diffusion rate.
|
||||||
final_diffusion_rate: The final diffusion rate.
|
final_diffusion_rate: The final diffusion rate.
|
||||||
noise_schedule: The noise schedule.
|
noise_schedule: The noise schedule.
|
||||||
|
@ -40,6 +44,8 @@ class Euler(Solver):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
timesteps_spacing=timesteps_spacing,
|
||||||
|
timesteps_offset=timesteps_offset,
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
noise_schedule=noise_schedule,
|
noise_schedule=noise_schedule,
|
||||||
|
@ -54,20 +60,6 @@ class Euler(Solver):
|
||||||
"""The initial noise sigma."""
|
"""The initial noise sigma."""
|
||||||
return self.sigmas.max()
|
return self.sigmas.max()
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
|
||||||
"""Generate the timesteps used by the solver.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
We need to use numpy here because:
|
|
||||||
|
|
||||||
- numpy.linspace(0,999,31)[15] is 499.49999999999994
|
|
||||||
- torch.linspace(0,999,31)[15] is 499.5
|
|
||||||
|
|
||||||
and we want the same result as the original codebase.
|
|
||||||
"""
|
|
||||||
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
|
|
||||||
return timesteps
|
|
||||||
|
|
||||||
def _generate_sigmas(self) -> Tensor:
|
def _generate_sigmas(self) -> Tensor:
|
||||||
"""Generate the sigmas used by the solver."""
|
"""Generate the sigmas used by the solver."""
|
||||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||||
|
@ -88,12 +80,17 @@ class Euler(Solver):
|
||||||
sigma = self.sigmas[step]
|
sigma = self.sigmas[step]
|
||||||
return x / ((sigma**2 + 1) ** 0.5)
|
return x / ((sigma**2 + 1) ** 0.5)
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
self,
|
"""Apply one step of the backward diffusion process.
|
||||||
x: Tensor,
|
|
||||||
predicted_noise: Tensor,
|
Args:
|
||||||
step: int,
|
x: The input tensor to apply the diffusion process to.
|
||||||
generator: Generator | None = None,
|
predicted_noise: The predicted noise tensor for the current step.
|
||||||
) -> Tensor:
|
step: The current step of the diffusion process.
|
||||||
|
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The denoised version of the input data `x`.
|
||||||
|
"""
|
||||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||||
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])
|
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||||
|
|
||||||
|
|
||||||
class LCMSolver(Solver):
|
class LCMSolver(Solver):
|
||||||
|
@ -20,14 +19,29 @@ class LCMSolver(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
num_train_timesteps: int = 1_000,
|
||||||
|
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING,
|
||||||
|
timesteps_offset: int = 0,
|
||||||
num_orig_steps: int = 50,
|
num_orig_steps: int = 50,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||||
diffusers_mode: bool = False,
|
|
||||||
device: torch.device | str = "cpu",
|
device: torch.device | str = "cpu",
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
):
|
):
|
||||||
|
"""Initializes a new LCM solver.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps: The number of inference steps.
|
||||||
|
num_train_timesteps: The number of training timesteps.
|
||||||
|
timesteps_spacing: The spacing to use for the timesteps.
|
||||||
|
timesteps_offset: The offset to use for the timesteps.
|
||||||
|
num_orig_steps: The number of inference steps of the emulated DPM solver.
|
||||||
|
initial_diffusion_rate: The initial diffusion rate.
|
||||||
|
final_diffusion_rate: The final diffusion rate.
|
||||||
|
noise_schedule: The noise schedule.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
assert (
|
assert (
|
||||||
num_orig_steps >= num_inference_steps
|
num_orig_steps >= num_inference_steps
|
||||||
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
|
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
|
||||||
|
@ -36,22 +50,17 @@ class LCMSolver(Solver):
|
||||||
DPMSolver(
|
DPMSolver(
|
||||||
num_inference_steps=num_orig_steps,
|
num_inference_steps=num_orig_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
timesteps_spacing=timesteps_spacing,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if diffusers_mode:
|
|
||||||
# Diffusers recomputes the timesteps in LCMScheduler,
|
|
||||||
# and it does it slightly differently than DPM Solver.
|
|
||||||
# We provide this option to reproduce Diffusers' output.
|
|
||||||
k = num_train_timesteps // num_orig_steps
|
|
||||||
ts = np.asarray(list(range(1, num_orig_steps + 1))) * k - 1
|
|
||||||
self.dpm.timesteps = torch.tensor(ts, device=device).flip(0)
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
timesteps_spacing=timesteps_spacing,
|
||||||
|
timesteps_offset=timesteps_offset,
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
noise_schedule=noise_schedule,
|
noise_schedule=noise_schedule,
|
||||||
|
@ -85,12 +94,19 @@ class LCMSolver(Solver):
|
||||||
return self.dpm.timesteps[self.timestep_indices]
|
return self.dpm.timesteps[self.timestep_indices]
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
|
||||||
x: torch.Tensor,
|
|
||||||
predicted_noise: torch.Tensor,
|
|
||||||
step: int,
|
|
||||||
generator: torch.Generator | None = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""Apply one step of the backward diffusion process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input tensor to apply the diffusion process to.
|
||||||
|
predicted_noise: The predicted noise tensor for the current step.
|
||||||
|
step: The current step of the diffusion process.
|
||||||
|
generator: The random number generator to use for sampling noise.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The denoised version of the input data `x`.
|
||||||
|
"""
|
||||||
current_timestep = self.timesteps[step]
|
current_timestep = self.timesteps[step]
|
||||||
scale_factor = self.cumulative_scale_factors[current_timestep]
|
scale_factor = self.cumulative_scale_factors[current_timestep]
|
||||||
noise_ratio = self.noise_std[current_timestep]
|
noise_ratio = self.noise_std[current_timestep]
|
||||||
|
|
|
@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
|
import numpy as np
|
||||||
|
from torch import Generator, Tensor, arange, device as Device, dtype as DType, float32, linspace, log, sqrt, tensor
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
|
|
||||||
|
@ -10,11 +11,11 @@ T = TypeVar("T", bound="Solver")
|
||||||
|
|
||||||
|
|
||||||
class NoiseSchedule(str, Enum):
|
class NoiseSchedule(str, Enum):
|
||||||
"""An enumeration of noise schedules used to sample the noise schedule.
|
"""An enumeration of schedules used to sample the noise.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
UNIFORM: A uniform noise schedule.
|
UNIFORM: A uniform noise schedule.
|
||||||
QUADRATIC: A quadratic noise schedule.
|
QUADRATIC: A quadratic noise schedule. Corresponds to "Stable Diffusion" in [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 1.
|
||||||
KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364)
|
KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -23,6 +24,26 @@ class NoiseSchedule(str, Enum):
|
||||||
KARRAS = "karras"
|
KARRAS = "karras"
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepSpacing(str, Enum):
|
||||||
|
"""An enumeration of methods to space the timesteps.
|
||||||
|
|
||||||
|
See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor.
|
||||||
|
LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps.
|
||||||
|
LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM.
|
||||||
|
TRAILING: Sample N+1 steps, do not include the first timestep.
|
||||||
|
TRAILING_ALT: Variant of TRAILING used in DPM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LINSPACE_FLOAT = "linspace_float"
|
||||||
|
LINSPACE_INT = "linspace_int"
|
||||||
|
LEADING = "leading"
|
||||||
|
TRAILING = "trailing"
|
||||||
|
TRAILING_ALT = "trailing_alt"
|
||||||
|
|
||||||
|
|
||||||
class Solver(fl.Module, ABC):
|
class Solver(fl.Module, ABC):
|
||||||
"""The base class for creating a diffusion model solver.
|
"""The base class for creating a diffusion model solver.
|
||||||
|
|
||||||
|
@ -39,6 +60,8 @@ class Solver(fl.Module, ABC):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
num_train_timesteps: int = 1_000,
|
||||||
|
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
|
||||||
|
timesteps_offset: int = 0,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||||
|
@ -51,6 +74,8 @@ class Solver(fl.Module, ABC):
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps to perform.
|
num_inference_steps: The number of inference steps to perform.
|
||||||
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
||||||
|
timesteps_spacing: The spacing to use for the timesteps.
|
||||||
|
timesteps_offset: The offset to use for the timesteps.
|
||||||
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
|
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
|
||||||
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
|
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
|
||||||
noise_schedule: The noise schedule used to sample the noise schedule.
|
noise_schedule: The noise schedule used to sample the noise schedule.
|
||||||
|
@ -61,6 +86,8 @@ class Solver(fl.Module, ABC):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
self.num_train_timesteps = num_train_timesteps
|
self.num_train_timesteps = num_train_timesteps
|
||||||
|
self.timesteps_spacing = timesteps_spacing
|
||||||
|
self.timesteps_offset = timesteps_offset
|
||||||
self.initial_diffusion_rate = initial_diffusion_rate
|
self.initial_diffusion_rate = initial_diffusion_rate
|
||||||
self.final_diffusion_rate = final_diffusion_rate
|
self.final_diffusion_rate = final_diffusion_rate
|
||||||
self.noise_schedule = noise_schedule
|
self.noise_schedule = noise_schedule
|
||||||
|
@ -87,14 +114,49 @@ class Solver(fl.Module, ABC):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@staticmethod
|
||||||
def _generate_timesteps(self) -> Tensor:
|
def generate_timesteps(
|
||||||
"""Generate a tensor of timesteps.
|
spacing: TimestepSpacing,
|
||||||
|
num_inference_steps: int,
|
||||||
|
num_train_timesteps: int = 1000,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Generate a tensor of timesteps according to a given spacing.
|
||||||
|
|
||||||
Note:
|
Args:
|
||||||
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
|
spacing: The spacing to use for the timesteps.
|
||||||
|
num_inference_steps: The number of inference steps to perform.
|
||||||
|
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
||||||
|
offset: The offset to use for the timesteps.
|
||||||
"""
|
"""
|
||||||
...
|
max_timestep = num_train_timesteps - 1 + offset
|
||||||
|
match spacing:
|
||||||
|
case TimestepSpacing.LINSPACE_FLOAT:
|
||||||
|
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
|
||||||
|
case TimestepSpacing.LINSPACE_INT:
|
||||||
|
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
|
||||||
|
case TimestepSpacing.LEADING:
|
||||||
|
step_ratio = num_train_timesteps // num_inference_steps
|
||||||
|
return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
|
||||||
|
case TimestepSpacing.TRAILING:
|
||||||
|
step_ratio = num_train_timesteps // num_inference_steps
|
||||||
|
max_timestep = num_train_timesteps - 1 + offset
|
||||||
|
return arange(max_timestep, offset, -step_ratio)
|
||||||
|
case TimestepSpacing.TRAILING_ALT:
|
||||||
|
# We use numpy here because:
|
||||||
|
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||||
|
# torch.linspace(0,999,31)[15] is 499.5
|
||||||
|
# and we want the same result as the original DPM codebase.
|
||||||
|
np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:]
|
||||||
|
return tensor(np_space).flip(0)
|
||||||
|
|
||||||
|
def _generate_timesteps(self) -> Tensor:
|
||||||
|
return self.generate_timesteps(
|
||||||
|
spacing=self.timesteps_spacing,
|
||||||
|
num_inference_steps=self.num_inference_steps,
|
||||||
|
num_train_timesteps=self.num_train_timesteps,
|
||||||
|
offset=self.timesteps_offset,
|
||||||
|
)
|
||||||
|
|
||||||
def add_noise(
|
def add_noise(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -2,10 +2,19 @@ from typing import cast
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn
|
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn, tensor
|
||||||
|
|
||||||
from refiners.fluxion import manual_seed
|
from refiners.fluxion import manual_seed
|
||||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, LCMSolver, NoiseSchedule
|
from refiners.foundationals.latent_diffusion.solvers import (
|
||||||
|
DDIM,
|
||||||
|
DDPM,
|
||||||
|
DPMSolver,
|
||||||
|
Euler,
|
||||||
|
LCMSolver,
|
||||||
|
NoiseSchedule,
|
||||||
|
Solver,
|
||||||
|
TimestepSpacing,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_ddpm_diffusers():
|
def test_ddpm_diffusers():
|
||||||
|
@ -14,7 +23,6 @@ def test_ddpm_diffusers():
|
||||||
diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
|
diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
|
||||||
diffusers_scheduler.set_timesteps(1000)
|
diffusers_scheduler.set_timesteps(1000)
|
||||||
refiners_scheduler = DDPM(num_inference_steps=1000)
|
refiners_scheduler = DDPM(num_inference_steps=1000)
|
||||||
|
|
||||||
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
|
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,6 +42,7 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
||||||
)
|
)
|
||||||
diffusers_scheduler.set_timesteps(n_steps)
|
diffusers_scheduler.set_timesteps(n_steps)
|
||||||
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
|
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
|
||||||
|
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
sample = randn(1, 3, 32, 32)
|
sample = randn(1, 3, 32, 32)
|
||||||
predicted_noise = randn(1, 3, 32, 32)
|
predicted_noise = randn(1, 3, 32, 32)
|
||||||
|
@ -59,6 +68,7 @@ def test_ddim_diffusers():
|
||||||
)
|
)
|
||||||
diffusers_scheduler.set_timesteps(30)
|
diffusers_scheduler.set_timesteps(30)
|
||||||
refiners_scheduler = DDIM(num_inference_steps=30)
|
refiners_scheduler = DDIM(num_inference_steps=30)
|
||||||
|
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
sample = randn(1, 4, 32, 32)
|
sample = randn(1, 4, 32, 32)
|
||||||
predicted_noise = randn(1, 4, 32, 32)
|
predicted_noise = randn(1, 4, 32, 32)
|
||||||
|
@ -85,6 +95,7 @@ def test_euler_diffusers():
|
||||||
)
|
)
|
||||||
diffusers_scheduler.set_timesteps(30)
|
diffusers_scheduler.set_timesteps(30)
|
||||||
refiners_scheduler = Euler(num_inference_steps=30)
|
refiners_scheduler = Euler(num_inference_steps=30)
|
||||||
|
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
sample = randn(1, 4, 32, 32)
|
sample = randn(1, 4, 32, 32)
|
||||||
predicted_noise = randn(1, 4, 32, 32)
|
predicted_noise = randn(1, 4, 32, 32)
|
||||||
|
@ -111,9 +122,7 @@ def test_lcm_diffusers():
|
||||||
|
|
||||||
diffusers_scheduler = LCMScheduler()
|
diffusers_scheduler = LCMScheduler()
|
||||||
diffusers_scheduler.set_timesteps(4)
|
diffusers_scheduler.set_timesteps(4)
|
||||||
refiners_scheduler = LCMSolver(num_inference_steps=4, diffusers_mode=True)
|
refiners_scheduler = LCMSolver(num_inference_steps=4)
|
||||||
|
|
||||||
# diffusers_mode means the timesteps are the same
|
|
||||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
sample = randn(1, 4, 32, 32)
|
sample = randn(1, 4, 32, 32)
|
||||||
|
@ -143,7 +152,7 @@ def test_lcm_diffusers():
|
||||||
assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
def test_scheduler_remove_noise():
|
def test_solver_remove_noise():
|
||||||
from diffusers import DDIMScheduler # type: ignore
|
from diffusers import DDIMScheduler # type: ignore
|
||||||
|
|
||||||
manual_seed(0)
|
manual_seed(0)
|
||||||
|
@ -169,7 +178,7 @@ def test_scheduler_remove_noise():
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
def test_scheduler_device(test_device: Device):
|
def test_solver_device(test_device: Device):
|
||||||
if test_device.type == "cpu":
|
if test_device.type == "cpu":
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
|
@ -182,8 +191,35 @@ def test_scheduler_device(test_device: Device):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
||||||
def test_scheduler_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
||||||
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
|
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
|
||||||
assert len(scheduler.scale_factors) == 1000
|
assert len(scheduler.scale_factors) == 1000
|
||||||
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
|
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
|
||||||
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate
|
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate
|
||||||
|
|
||||||
|
|
||||||
|
def test_solver_timestep_spacing():
|
||||||
|
# Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
||||||
|
linspace_int = Solver.generate_timesteps(
|
||||||
|
spacing=TimestepSpacing.LINSPACE_INT,
|
||||||
|
num_inference_steps=10,
|
||||||
|
num_train_timesteps=1000,
|
||||||
|
offset=1,
|
||||||
|
)
|
||||||
|
assert equal(linspace_int, tensor([1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]))
|
||||||
|
|
||||||
|
leading = Solver.generate_timesteps(
|
||||||
|
spacing=TimestepSpacing.LEADING,
|
||||||
|
num_inference_steps=10,
|
||||||
|
num_train_timesteps=1000,
|
||||||
|
offset=1,
|
||||||
|
)
|
||||||
|
assert equal(leading, tensor([901, 801, 701, 601, 501, 401, 301, 201, 101, 1]))
|
||||||
|
|
||||||
|
trailing = Solver.generate_timesteps(
|
||||||
|
spacing=TimestepSpacing.TRAILING,
|
||||||
|
num_inference_steps=10,
|
||||||
|
num_train_timesteps=1000,
|
||||||
|
offset=1,
|
||||||
|
)
|
||||||
|
assert equal(trailing, tensor([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]))
|
||||||
|
|
Loading…
Reference in a new issue