mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
refactor solvers to support different timestep spacings
This commit is contained in:
parent
d14c5bd5f8
commit
ddc1cf8ca7
|
@ -3,6 +3,6 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
|||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
||||
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
|
||||
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule"]
|
||||
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
|
||||
|
||||
class DDIM(Solver):
|
||||
|
@ -13,6 +13,8 @@ class DDIM(Solver):
|
|||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
|
||||
timesteps_offset: int = 1,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
|
@ -25,6 +27,8 @@ class DDIM(Solver):
|
|||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
noise_schedule: The noise schedule.
|
||||
|
@ -35,6 +39,8 @@ class DDIM(Solver):
|
|||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
|
@ -43,16 +49,18 @@ class DDIM(Solver):
|
|||
dtype=dtype,
|
||||
)
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
"""
|
||||
Generates decreasing timesteps with 'leading' spacing and offset of 1
|
||||
similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5
|
||||
"""
|
||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
|
||||
return timesteps.flip(0)
|
||||
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
"""Apply one step of the backward diffusion process.
|
||||
|
||||
Args:
|
||||
x: The input tensor to apply the diffusion process to.
|
||||
predicted_noise: The predicted noise tensor for the current step.
|
||||
step: The current step of the diffusion process.
|
||||
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
|
||||
|
||||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
"""
|
||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||
|
||||
timestep, previous_timestep = (
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from torch import Generator, Tensor, arange, device as Device
|
||||
from torch import Generator, Tensor, device as Device
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
||||
|
||||
|
||||
class DDPM(Solver):
|
||||
|
@ -17,6 +17,8 @@ class DDPM(Solver):
|
|||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
|
||||
timesteps_offset: int = 0,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
first_inference_step: int = 0,
|
||||
|
@ -27,6 +29,8 @@ class DDPM(Solver):
|
|||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
first_inference_step: The first inference step.
|
||||
|
@ -35,16 +39,13 @@ class DDPM(Solver):
|
|||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
first_inference_step=first_inference_step,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
|
||||
return timesteps.flip(0)
|
||||
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
|
||||
|
||||
class DPMSolver(Solver):
|
||||
|
@ -23,6 +22,8 @@ class DPMSolver(Solver):
|
|||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT,
|
||||
timesteps_offset: int = 0,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
last_step_first_order: bool = False,
|
||||
|
@ -31,9 +32,26 @@ class DPMSolver(Solver):
|
|||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
"""Initializes a new DPM solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
last_step_first_order: Use a first-order update for the last step.
|
||||
noise_schedule: The noise schedule.
|
||||
first_inference_step: The first inference step.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
|
@ -44,21 +62,6 @@ class DPMSolver(Solver):
|
|||
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
||||
self.last_step_first_order = last_step_first_order
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
"""Generate the timesteps used by the solver.
|
||||
|
||||
Note:
|
||||
We need to use numpy here because:
|
||||
|
||||
- numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
- torch.linspace(0,999,31)[15] is 499.5
|
||||
|
||||
and we want the same result as the original codebase.
|
||||
"""
|
||||
return tensor(
|
||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
|
||||
).flip(0)
|
||||
|
||||
def rebuild(
|
||||
self: "DPMSolver",
|
||||
num_inference_steps: int | None,
|
||||
|
@ -148,10 +151,10 @@ class DPMSolver(Solver):
|
|||
(ODEs).
|
||||
|
||||
Args:
|
||||
x: The input data.
|
||||
predicted_noise: The predicted noise.
|
||||
step: The current step.
|
||||
generator: The random number generator.
|
||||
x: The input tensor to apply the diffusion process to.
|
||||
predicted_noise: The predicted noise tensor for the current step.
|
||||
step: The current step of the diffusion process.
|
||||
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
|
||||
|
||||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
|
|
|
@ -2,7 +2,7 @@ import numpy as np
|
|||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
|
||||
|
||||
class Euler(Solver):
|
||||
|
@ -16,6 +16,8 @@ class Euler(Solver):
|
|||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
|
||||
timesteps_offset: int = 0,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
|
@ -28,6 +30,8 @@ class Euler(Solver):
|
|||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
noise_schedule: The noise schedule.
|
||||
|
@ -40,6 +44,8 @@ class Euler(Solver):
|
|||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
|
@ -54,20 +60,6 @@ class Euler(Solver):
|
|||
"""The initial noise sigma."""
|
||||
return self.sigmas.max()
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
"""Generate the timesteps used by the solver.
|
||||
|
||||
Note:
|
||||
We need to use numpy here because:
|
||||
|
||||
- numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
- torch.linspace(0,999,31)[15] is 499.5
|
||||
|
||||
and we want the same result as the original codebase.
|
||||
"""
|
||||
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
|
||||
return timesteps
|
||||
|
||||
def _generate_sigmas(self) -> Tensor:
|
||||
"""Generate the sigmas used by the solver."""
|
||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||
|
@ -88,12 +80,17 @@ class Euler(Solver):
|
|||
sigma = self.sigmas[step]
|
||||
return x / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
predicted_noise: Tensor,
|
||||
step: int,
|
||||
generator: Generator | None = None,
|
||||
) -> Tensor:
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
"""Apply one step of the backward diffusion process.
|
||||
|
||||
Args:
|
||||
x: The input tensor to apply the diffusion process to.
|
||||
predicted_noise: The predicted noise tensor for the current step.
|
||||
step: The current step of the diffusion process.
|
||||
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
|
||||
|
||||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
"""
|
||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
||||
|
||||
|
||||
class LCMSolver(Solver):
|
||||
|
@ -20,14 +19,29 @@ class LCMSolver(Solver):
|
|||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING,
|
||||
timesteps_offset: int = 0,
|
||||
num_orig_steps: int = 50,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
diffusers_mode: bool = False,
|
||||
device: torch.device | str = "cpu",
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
"""Initializes a new LCM solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
num_orig_steps: The number of inference steps of the emulated DPM solver.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
noise_schedule: The noise schedule.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
assert (
|
||||
num_orig_steps >= num_inference_steps
|
||||
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
|
||||
|
@ -36,22 +50,17 @@ class LCMSolver(Solver):
|
|||
DPMSolver(
|
||||
num_inference_steps=num_orig_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
]
|
||||
|
||||
if diffusers_mode:
|
||||
# Diffusers recomputes the timesteps in LCMScheduler,
|
||||
# and it does it slightly differently than DPM Solver.
|
||||
# We provide this option to reproduce Diffusers' output.
|
||||
k = num_train_timesteps // num_orig_steps
|
||||
ts = np.asarray(list(range(1, num_orig_steps + 1))) * k - 1
|
||||
self.dpm.timesteps = torch.tensor(ts, device=device).flip(0)
|
||||
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
timesteps_spacing=timesteps_spacing,
|
||||
timesteps_offset=timesteps_offset,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
|
@ -85,12 +94,19 @@ class LCMSolver(Solver):
|
|||
return self.dpm.timesteps[self.timestep_indices]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
predicted_noise: torch.Tensor,
|
||||
step: int,
|
||||
generator: torch.Generator | None = None,
|
||||
self, x: torch.Tensor, predicted_noise: torch.Tensor, step: int, generator: torch.Generator | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Apply one step of the backward diffusion process.
|
||||
|
||||
Args:
|
||||
x: The input tensor to apply the diffusion process to.
|
||||
predicted_noise: The predicted noise tensor for the current step.
|
||||
step: The current step of the diffusion process.
|
||||
generator: The random number generator to use for sampling noise.
|
||||
|
||||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
"""
|
||||
current_timestep = self.timesteps[step]
|
||||
scale_factor = self.cumulative_scale_factors[current_timestep]
|
||||
noise_ratio = self.noise_std[current_timestep]
|
||||
|
|
|
@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
|
|||
from enum import Enum
|
||||
from typing import TypeVar
|
||||
|
||||
from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
|
||||
import numpy as np
|
||||
from torch import Generator, Tensor, arange, device as Device, dtype as DType, float32, linspace, log, sqrt, tensor
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
|
||||
|
@ -10,11 +11,11 @@ T = TypeVar("T", bound="Solver")
|
|||
|
||||
|
||||
class NoiseSchedule(str, Enum):
|
||||
"""An enumeration of noise schedules used to sample the noise schedule.
|
||||
"""An enumeration of schedules used to sample the noise.
|
||||
|
||||
Attributes:
|
||||
UNIFORM: A uniform noise schedule.
|
||||
QUADRATIC: A quadratic noise schedule.
|
||||
QUADRATIC: A quadratic noise schedule. Corresponds to "Stable Diffusion" in [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 1.
|
||||
KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364)
|
||||
"""
|
||||
|
||||
|
@ -23,6 +24,26 @@ class NoiseSchedule(str, Enum):
|
|||
KARRAS = "karras"
|
||||
|
||||
|
||||
class TimestepSpacing(str, Enum):
|
||||
"""An enumeration of methods to space the timesteps.
|
||||
|
||||
See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
||||
|
||||
Attributes:
|
||||
LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor.
|
||||
LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps.
|
||||
LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM.
|
||||
TRAILING: Sample N+1 steps, do not include the first timestep.
|
||||
TRAILING_ALT: Variant of TRAILING used in DPM.
|
||||
"""
|
||||
|
||||
LINSPACE_FLOAT = "linspace_float"
|
||||
LINSPACE_INT = "linspace_int"
|
||||
LEADING = "leading"
|
||||
TRAILING = "trailing"
|
||||
TRAILING_ALT = "trailing_alt"
|
||||
|
||||
|
||||
class Solver(fl.Module, ABC):
|
||||
"""The base class for creating a diffusion model solver.
|
||||
|
||||
|
@ -39,6 +60,8 @@ class Solver(fl.Module, ABC):
|
|||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
|
||||
timesteps_offset: int = 0,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
|
@ -51,6 +74,8 @@ class Solver(fl.Module, ABC):
|
|||
Args:
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
||||
timesteps_spacing: The spacing to use for the timesteps.
|
||||
timesteps_offset: The offset to use for the timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
|
||||
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
|
||||
noise_schedule: The noise schedule used to sample the noise schedule.
|
||||
|
@ -61,6 +86,8 @@ class Solver(fl.Module, ABC):
|
|||
super().__init__()
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.timesteps_spacing = timesteps_spacing
|
||||
self.timesteps_offset = timesteps_offset
|
||||
self.initial_diffusion_rate = initial_diffusion_rate
|
||||
self.final_diffusion_rate = final_diffusion_rate
|
||||
self.noise_schedule = noise_schedule
|
||||
|
@ -87,14 +114,49 @@ class Solver(fl.Module, ABC):
|
|||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
"""Generate a tensor of timesteps.
|
||||
@staticmethod
|
||||
def generate_timesteps(
|
||||
spacing: TimestepSpacing,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1000,
|
||||
offset: int = 0,
|
||||
) -> Tensor:
|
||||
"""Generate a tensor of timesteps according to a given spacing.
|
||||
|
||||
Note:
|
||||
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
|
||||
Args:
|
||||
spacing: The spacing to use for the timesteps.
|
||||
num_inference_steps: The number of inference steps to perform.
|
||||
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
||||
offset: The offset to use for the timesteps.
|
||||
"""
|
||||
...
|
||||
max_timestep = num_train_timesteps - 1 + offset
|
||||
match spacing:
|
||||
case TimestepSpacing.LINSPACE_FLOAT:
|
||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
|
||||
case TimestepSpacing.LINSPACE_INT:
|
||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
|
||||
case TimestepSpacing.LEADING:
|
||||
step_ratio = num_train_timesteps // num_inference_steps
|
||||
return (arange(0, num_inference_steps, 1) * step_ratio + offset).flip(0)
|
||||
case TimestepSpacing.TRAILING:
|
||||
step_ratio = num_train_timesteps // num_inference_steps
|
||||
max_timestep = num_train_timesteps - 1 + offset
|
||||
return arange(max_timestep, offset, -step_ratio)
|
||||
case TimestepSpacing.TRAILING_ALT:
|
||||
# We use numpy here because:
|
||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
# torch.linspace(0,999,31)[15] is 499.5
|
||||
# and we want the same result as the original DPM codebase.
|
||||
np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:]
|
||||
return tensor(np_space).flip(0)
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
return self.generate_timesteps(
|
||||
spacing=self.timesteps_spacing,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
offset=self.timesteps_offset,
|
||||
)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
|
|
|
@ -2,10 +2,19 @@ from typing import cast
|
|||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn
|
||||
from torch import Generator, Tensor, allclose, device as Device, equal, isclose, randn, tensor
|
||||
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, DDPM, DPMSolver, Euler, LCMSolver, NoiseSchedule
|
||||
from refiners.foundationals.latent_diffusion.solvers import (
|
||||
DDIM,
|
||||
DDPM,
|
||||
DPMSolver,
|
||||
Euler,
|
||||
LCMSolver,
|
||||
NoiseSchedule,
|
||||
Solver,
|
||||
TimestepSpacing,
|
||||
)
|
||||
|
||||
|
||||
def test_ddpm_diffusers():
|
||||
|
@ -14,7 +23,6 @@ def test_ddpm_diffusers():
|
|||
diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
|
||||
diffusers_scheduler.set_timesteps(1000)
|
||||
refiners_scheduler = DDPM(num_inference_steps=1000)
|
||||
|
||||
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
|
||||
|
||||
|
||||
|
@ -34,6 +42,7 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
|||
)
|
||||
diffusers_scheduler.set_timesteps(n_steps)
|
||||
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
|
||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 3, 32, 32)
|
||||
predicted_noise = randn(1, 3, 32, 32)
|
||||
|
@ -59,6 +68,7 @@ def test_ddim_diffusers():
|
|||
)
|
||||
diffusers_scheduler.set_timesteps(30)
|
||||
refiners_scheduler = DDIM(num_inference_steps=30)
|
||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 4, 32, 32)
|
||||
predicted_noise = randn(1, 4, 32, 32)
|
||||
|
@ -85,6 +95,7 @@ def test_euler_diffusers():
|
|||
)
|
||||
diffusers_scheduler.set_timesteps(30)
|
||||
refiners_scheduler = Euler(num_inference_steps=30)
|
||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 4, 32, 32)
|
||||
predicted_noise = randn(1, 4, 32, 32)
|
||||
|
@ -111,9 +122,7 @@ def test_lcm_diffusers():
|
|||
|
||||
diffusers_scheduler = LCMScheduler()
|
||||
diffusers_scheduler.set_timesteps(4)
|
||||
refiners_scheduler = LCMSolver(num_inference_steps=4, diffusers_mode=True)
|
||||
|
||||
# diffusers_mode means the timesteps are the same
|
||||
refiners_scheduler = LCMSolver(num_inference_steps=4)
|
||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||
|
||||
sample = randn(1, 4, 32, 32)
|
||||
|
@ -143,7 +152,7 @@ def test_lcm_diffusers():
|
|||
assert allclose(refiners_output, diffusers_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_scheduler_remove_noise():
|
||||
def test_solver_remove_noise():
|
||||
from diffusers import DDIMScheduler # type: ignore
|
||||
|
||||
manual_seed(0)
|
||||
|
@ -169,7 +178,7 @@ def test_scheduler_remove_noise():
|
|||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_scheduler_device(test_device: Device):
|
||||
def test_solver_device(test_device: Device):
|
||||
if test_device.type == "cpu":
|
||||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
@ -182,8 +191,35 @@ def test_scheduler_device(test_device: Device):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
||||
def test_scheduler_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
||||
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
||||
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
|
||||
assert len(scheduler.scale_factors) == 1000
|
||||
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
|
||||
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate
|
||||
|
||||
|
||||
def test_solver_timestep_spacing():
|
||||
# Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
||||
linspace_int = Solver.generate_timesteps(
|
||||
spacing=TimestepSpacing.LINSPACE_INT,
|
||||
num_inference_steps=10,
|
||||
num_train_timesteps=1000,
|
||||
offset=1,
|
||||
)
|
||||
assert equal(linspace_int, tensor([1000, 889, 778, 667, 556, 445, 334, 223, 112, 1]))
|
||||
|
||||
leading = Solver.generate_timesteps(
|
||||
spacing=TimestepSpacing.LEADING,
|
||||
num_inference_steps=10,
|
||||
num_train_timesteps=1000,
|
||||
offset=1,
|
||||
)
|
||||
assert equal(leading, tensor([901, 801, 701, 601, 501, 401, 301, 201, 101, 1]))
|
||||
|
||||
trailing = Solver.generate_timesteps(
|
||||
spacing=TimestepSpacing.TRAILING,
|
||||
num_inference_steps=10,
|
||||
num_train_timesteps=1000,
|
||||
offset=1,
|
||||
)
|
||||
assert equal(trailing, tensor([1000, 900, 800, 700, 600, 500, 400, 300, 200, 100]))
|
||||
|
|
Loading…
Reference in a new issue