mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
refactor solver params, add sample prediction type
This commit is contained in:
parent
ddc1cf8ca7
commit
bf0ba58541
|
@ -3,6 +3,23 @@ from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
|
||||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
|
||||||
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
from refiners.foundationals.latent_diffusion.solvers.lcm import LCMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
ModelPredictionType,
|
||||||
|
NoiseSchedule,
|
||||||
|
Solver,
|
||||||
|
SolverParams,
|
||||||
|
TimestepSpacing,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler", "LCMSolver", "NoiseSchedule", "TimestepSpacing"]
|
__all__ = [
|
||||||
|
"Solver",
|
||||||
|
"SolverParams",
|
||||||
|
"DPMSolver",
|
||||||
|
"DDPM",
|
||||||
|
"DDIM",
|
||||||
|
"Euler",
|
||||||
|
"LCMSolver",
|
||||||
|
"ModelPredictionType",
|
||||||
|
"NoiseSchedule",
|
||||||
|
"TimestepSpacing",
|
||||||
|
]
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
ModelPredictionType,
|
||||||
|
Solver,
|
||||||
|
SolverParams,
|
||||||
|
TimestepSpacing,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DDIM(Solver):
|
class DDIM(Solver):
|
||||||
|
@ -9,42 +16,36 @@ class DDIM(Solver):
|
||||||
See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details.
|
See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
default_params = dataclasses.replace(
|
||||||
|
Solver.default_params,
|
||||||
|
timesteps_spacing=TimestepSpacing.LEADING,
|
||||||
|
timesteps_offset=1,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
|
||||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
|
|
||||||
timesteps_offset: int = 1,
|
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
|
params: SolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes a new DDIM solver.
|
"""Initializes a new DDIM solver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps.
|
num_inference_steps: The number of inference steps to perform.
|
||||||
num_train_timesteps: The number of training timesteps.
|
first_inference_step: The first inference step to perform.
|
||||||
timesteps_spacing: The spacing to use for the timesteps.
|
params: The common parameters for solvers.
|
||||||
timesteps_offset: The offset to use for the timesteps.
|
|
||||||
initial_diffusion_rate: The initial diffusion rate.
|
|
||||||
final_diffusion_rate: The final diffusion rate.
|
|
||||||
noise_schedule: The noise schedule.
|
|
||||||
first_inference_step: The first inference step.
|
|
||||||
device: The PyTorch device to use.
|
device: The PyTorch device to use.
|
||||||
dtype: The PyTorch data type to use.
|
dtype: The PyTorch data type to use.
|
||||||
"""
|
"""
|
||||||
|
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
|
||||||
timesteps_spacing=timesteps_spacing,
|
|
||||||
timesteps_offset=timesteps_offset,
|
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
|
||||||
noise_schedule=noise_schedule,
|
|
||||||
first_inference_step=first_inference_step,
|
first_inference_step=first_inference_step,
|
||||||
|
params=params,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
from torch import Generator, Tensor, device as Device
|
from torch import Generator, Tensor, device as Device
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver, TimestepSpacing
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
ModelPredictionType,
|
||||||
|
Solver,
|
||||||
|
SolverParams,
|
||||||
|
TimestepSpacing,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DDPM(Solver):
|
class DDPM(Solver):
|
||||||
|
@ -13,37 +20,34 @@ class DDPM(Solver):
|
||||||
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details.
|
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
default_params = dataclasses.replace(
|
||||||
|
Solver.default_params,
|
||||||
|
timesteps_spacing=TimestepSpacing.LEADING,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
|
||||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LEADING,
|
|
||||||
timesteps_offset: int = 0,
|
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
|
params: SolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes a new DDPM solver.
|
"""Initializes a new DDPM solver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps.
|
num_inference_steps: The number of inference steps to perform.
|
||||||
num_train_timesteps: The number of training timesteps.
|
first_inference_step: The first inference step to perform.
|
||||||
timesteps_spacing: The spacing to use for the timesteps.
|
params: The common parameters for solvers.
|
||||||
timesteps_offset: The offset to use for the timesteps.
|
|
||||||
initial_diffusion_rate: The initial diffusion rate.
|
|
||||||
final_diffusion_rate: The final diffusion rate.
|
|
||||||
first_inference_step: The first inference step.
|
|
||||||
device: The PyTorch device to use.
|
device: The PyTorch device to use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
|
||||||
timesteps_spacing=timesteps_spacing,
|
|
||||||
timesteps_offset=timesteps_offset,
|
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
|
||||||
first_inference_step=first_inference_step,
|
first_inference_step=first_inference_step,
|
||||||
|
params=params,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,15 @@
|
||||||
|
import dataclasses
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
ModelPredictionType,
|
||||||
|
Solver,
|
||||||
|
SolverParams,
|
||||||
|
TimestepSpacing,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DPMSolver(Solver):
|
class DPMSolver(Solver):
|
||||||
|
@ -18,44 +25,37 @@ class DPMSolver(Solver):
|
||||||
for the last step of the diffusion.
|
for the last step of the diffusion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
default_params = dataclasses.replace(
|
||||||
|
Solver.default_params,
|
||||||
|
timesteps_spacing=TimestepSpacing.CUSTOM,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
|
||||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING_ALT,
|
|
||||||
timesteps_offset: int = 0,
|
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
|
||||||
last_step_first_order: bool = False,
|
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
|
params: SolverParams | None = None,
|
||||||
|
last_step_first_order: bool = False,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
):
|
):
|
||||||
"""Initializes a new DPM solver.
|
"""Initializes a new DPM solver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps.
|
num_inference_steps: The number of inference steps to perform.
|
||||||
num_train_timesteps: The number of training timesteps.
|
first_inference_step: The first inference step to perform.
|
||||||
timesteps_spacing: The spacing to use for the timesteps.
|
params: The common parameters for solvers.
|
||||||
timesteps_offset: The offset to use for the timesteps.
|
|
||||||
initial_diffusion_rate: The initial diffusion rate.
|
|
||||||
final_diffusion_rate: The final diffusion rate.
|
|
||||||
last_step_first_order: Use a first-order update for the last step.
|
last_step_first_order: Use a first-order update for the last step.
|
||||||
noise_schedule: The noise schedule.
|
|
||||||
first_inference_step: The first inference step.
|
|
||||||
device: The PyTorch device to use.
|
device: The PyTorch device to use.
|
||||||
dtype: The PyTorch data type to use.
|
dtype: The PyTorch data type to use.
|
||||||
"""
|
"""
|
||||||
|
if params and params.model_prediction_type not in (ModelPredictionType.NOISE, None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
|
||||||
timesteps_spacing=timesteps_spacing,
|
|
||||||
timesteps_offset=timesteps_offset,
|
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
|
||||||
noise_schedule=noise_schedule,
|
|
||||||
first_inference_step=first_inference_step,
|
first_inference_step=first_inference_step,
|
||||||
|
params=params,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
@ -80,6 +80,19 @@ class DPMSolver(Solver):
|
||||||
r.last_step_first_order = self.last_step_first_order
|
r.last_step_first_order = self.last_step_first_order
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
def _generate_timesteps(self) -> Tensor:
|
||||||
|
if self.params.timesteps_spacing != TimestepSpacing.CUSTOM:
|
||||||
|
return super()._generate_timesteps()
|
||||||
|
|
||||||
|
# We use numpy here because:
|
||||||
|
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||||
|
# torch.linspace(0,999,31)[15] is 499.5
|
||||||
|
# and we want the same result as the original DPM codebase.
|
||||||
|
offset = self.params.timesteps_offset
|
||||||
|
max_timestep = self.params.num_train_timesteps - 1 + offset
|
||||||
|
np_space = np.linspace(offset, max_timestep, self.num_inference_steps + 1).round().astype(int)[1:]
|
||||||
|
return tensor(np_space).flip(0)
|
||||||
|
|
||||||
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
"""Applies a first-order backward Euler update to the input data `x`.
|
"""Applies a first-order backward Euler update to the input data `x`.
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,12 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
ModelPredictionType,
|
||||||
|
NoiseSchedule,
|
||||||
|
Solver,
|
||||||
|
SolverParams,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Euler(Solver):
|
class Euler(Solver):
|
||||||
|
@ -15,41 +20,27 @@ class Euler(Solver):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
|
||||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
|
|
||||||
timesteps_offset: int = 0,
|
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
|
params: SolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
):
|
):
|
||||||
"""Initializes a new Euler solver.
|
"""Initializes a new Euler solver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps.
|
num_inference_steps: The number of inference steps to perform.
|
||||||
num_train_timesteps: The number of training timesteps.
|
first_inference_step: The first inference step to perform.
|
||||||
timesteps_spacing: The spacing to use for the timesteps.
|
params: The common parameters for solvers.
|
||||||
timesteps_offset: The offset to use for the timesteps.
|
|
||||||
initial_diffusion_rate: The initial diffusion rate.
|
|
||||||
final_diffusion_rate: The final diffusion rate.
|
|
||||||
noise_schedule: The noise schedule.
|
|
||||||
first_inference_step: The first inference step.
|
|
||||||
device: The PyTorch device to use.
|
device: The PyTorch device to use.
|
||||||
dtype: The PyTorch data type to use.
|
dtype: The PyTorch data type to use.
|
||||||
"""
|
"""
|
||||||
if noise_schedule != NoiseSchedule.QUADRATIC:
|
if params and params.noise_schedule not in (NoiseSchedule.QUADRATIC, None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
|
||||||
timesteps_spacing=timesteps_spacing,
|
|
||||||
timesteps_offset=timesteps_offset,
|
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
|
||||||
noise_schedule=noise_schedule,
|
|
||||||
first_inference_step=first_inference_step,
|
first_inference_step=first_inference_step,
|
||||||
|
params=params,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
@ -85,7 +76,7 @@ class Euler(Solver):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: The input tensor to apply the diffusion process to.
|
x: The input tensor to apply the diffusion process to.
|
||||||
predicted_noise: The predicted noise tensor for the current step.
|
predicted_noise: The predicted noise tensor for the current step (or x0 if the prediction type is SAMPLE).
|
||||||
step: The current step of the diffusion process.
|
step: The current step of the diffusion process.
|
||||||
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
|
generator: The random number generator to use for sampling noise (ignored, this solver is deterministic).
|
||||||
|
|
||||||
|
@ -93,4 +84,11 @@ class Euler(Solver):
|
||||||
The denoised version of the input data `x`.
|
The denoised version of the input data `x`.
|
||||||
"""
|
"""
|
||||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||||
|
|
||||||
|
if self.params.model_prediction_type == ModelPredictionType.SAMPLE:
|
||||||
|
x0 = predicted_noise # the model does not actually predict the noise but x0
|
||||||
|
ratio = self.sigmas[step + 1] / self.sigmas[step]
|
||||||
|
return ratio * x + (1 - ratio) * x0
|
||||||
|
|
||||||
|
assert self.params.model_prediction_type == ModelPredictionType.NOISE
|
||||||
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])
|
return x + predicted_noise * (self.sigmas[step + 1] - self.sigmas[step])
|
||||||
|
|
|
@ -1,7 +1,14 @@
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver, TimestepSpacing
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
ModelPredictionType,
|
||||||
|
Solver,
|
||||||
|
SolverParams,
|
||||||
|
TimestepSpacing,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LCMSolver(Solver):
|
class LCMSolver(Solver):
|
||||||
|
@ -15,55 +22,52 @@ class LCMSolver(Solver):
|
||||||
for details.
|
for details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# The spacing parameter is actually the spacing of the underlying DPM solver.
|
||||||
|
default_params = dataclasses.replace(Solver.default_params, timesteps_spacing=TimestepSpacing.TRAILING)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
first_inference_step: int = 0,
|
||||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.TRAILING,
|
params: SolverParams | None = None,
|
||||||
timesteps_offset: int = 0,
|
|
||||||
num_orig_steps: int = 50,
|
num_orig_steps: int = 50,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
|
||||||
device: torch.device | str = "cpu",
|
device: torch.device | str = "cpu",
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
):
|
):
|
||||||
"""Initializes a new LCM solver.
|
"""Initializes a new LCM solver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps.
|
num_inference_steps: The number of inference steps to perform.
|
||||||
num_train_timesteps: The number of training timesteps.
|
first_inference_step: The first inference step to perform.
|
||||||
timesteps_spacing: The spacing to use for the timesteps.
|
params: The common parameters for solvers.
|
||||||
timesteps_offset: The offset to use for the timesteps.
|
|
||||||
num_orig_steps: The number of inference steps of the emulated DPM solver.
|
num_orig_steps: The number of inference steps of the emulated DPM solver.
|
||||||
initial_diffusion_rate: The initial diffusion rate.
|
|
||||||
final_diffusion_rate: The final diffusion rate.
|
|
||||||
noise_schedule: The noise schedule.
|
|
||||||
device: The PyTorch device to use.
|
device: The PyTorch device to use.
|
||||||
dtype: The PyTorch data type to use.
|
dtype: The PyTorch data type to use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
num_orig_steps >= num_inference_steps
|
num_orig_steps >= num_inference_steps
|
||||||
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
|
), f"num_orig_steps ({num_orig_steps}) < num_inference_steps ({num_inference_steps})"
|
||||||
|
|
||||||
|
params = self.resolve_params(params)
|
||||||
|
if params.model_prediction_type != ModelPredictionType.NOISE:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
self._dpm = [
|
self._dpm = [
|
||||||
DPMSolver(
|
DPMSolver(
|
||||||
num_inference_steps=num_orig_steps,
|
num_inference_steps=num_orig_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
params=SolverParams(
|
||||||
timesteps_spacing=timesteps_spacing,
|
num_train_timesteps=params.num_train_timesteps,
|
||||||
|
timesteps_spacing=params.timesteps_spacing,
|
||||||
|
),
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
first_inference_step=first_inference_step,
|
||||||
timesteps_spacing=timesteps_spacing,
|
params=params,
|
||||||
timesteps_offset=timesteps_offset,
|
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
|
||||||
noise_schedule=noise_schedule,
|
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import dataclasses
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
@ -30,18 +31,64 @@ class TimestepSpacing(str, Enum):
|
||||||
See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
See [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
LINSPACE_FLOAT: Sample N steps with linear interpolation, return a floating-point tensor.
|
LINSPACE: Sample N steps with linear interpolation, return a floating-point tensor.
|
||||||
LINSPACE_INT: Same as LINSPACE_FLOAT but return an integer tensor with rounded timesteps.
|
LINSPACE_ROUNDED: Same as LINSPACE but return an integer tensor with rounded timesteps.
|
||||||
LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM.
|
LEADING: Sample N+1 steps, do not include the last timestep (i.e. bad - non-zero SNR). Used in DDIM.
|
||||||
TRAILING: Sample N+1 steps, do not include the first timestep.
|
TRAILING: Sample N+1 steps, do not include the first timestep.
|
||||||
TRAILING_ALT: Variant of TRAILING used in DPM.
|
CUSTOM: Use custom timespacing in solver (override `_generate_timesteps`, see DPM).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LINSPACE_FLOAT = "linspace_float"
|
LINSPACE = "linspace"
|
||||||
LINSPACE_INT = "linspace_int"
|
LINSPACE_ROUNDED = "linspace_rounded"
|
||||||
LEADING = "leading"
|
LEADING = "leading"
|
||||||
TRAILING = "trailing"
|
TRAILING = "trailing"
|
||||||
TRAILING_ALT = "trailing_alt"
|
CUSTOM = "custom"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPredictionType(str, Enum):
|
||||||
|
"""An enumeration of possible outputs of the model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
NOISE: The model predicts the noise (epsilon).
|
||||||
|
SAMPLE: The model predicts the denoised sample (x0).
|
||||||
|
"""
|
||||||
|
|
||||||
|
NOISE = "noise"
|
||||||
|
SAMPLE = "sample"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||||
|
class SolverParams:
|
||||||
|
"""Common parameters for solvers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
||||||
|
timesteps_spacing: The spacing to use for the timesteps.
|
||||||
|
timesteps_offset: The offset to use for the timesteps.
|
||||||
|
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
|
||||||
|
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
|
||||||
|
noise_schedule: The noise schedule used to sample the noise schedule.
|
||||||
|
model_prediction_type: Defines what the model predicts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_train_timesteps: int | None = None
|
||||||
|
timesteps_spacing: TimestepSpacing | None = None
|
||||||
|
timesteps_offset: int | None = None
|
||||||
|
initial_diffusion_rate: float | None = None
|
||||||
|
final_diffusion_rate: float | None = None
|
||||||
|
noise_schedule: NoiseSchedule | None = None
|
||||||
|
model_prediction_type: ModelPredictionType | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||||
|
class ResolvedSolverParams(SolverParams):
|
||||||
|
num_train_timesteps: int
|
||||||
|
timesteps_spacing: TimestepSpacing
|
||||||
|
timesteps_offset: int
|
||||||
|
initial_diffusion_rate: float
|
||||||
|
final_diffusion_rate: float
|
||||||
|
noise_schedule: NoiseSchedule
|
||||||
|
model_prediction_type: ModelPredictionType
|
||||||
|
|
||||||
|
|
||||||
class Solver(fl.Module, ABC):
|
class Solver(fl.Module, ABC):
|
||||||
|
@ -55,17 +102,23 @@ class Solver(fl.Module, ABC):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
timesteps: Tensor
|
timesteps: Tensor
|
||||||
|
params: ResolvedSolverParams
|
||||||
|
|
||||||
|
default_params = ResolvedSolverParams(
|
||||||
|
num_train_timesteps=1000,
|
||||||
|
timesteps_spacing=TimestepSpacing.LINSPACE,
|
||||||
|
timesteps_offset=0,
|
||||||
|
initial_diffusion_rate=8.5e-4,
|
||||||
|
final_diffusion_rate=1.2e-2,
|
||||||
|
noise_schedule=NoiseSchedule.QUADRATIC,
|
||||||
|
model_prediction_type=ModelPredictionType.NOISE,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
num_train_timesteps: int = 1_000,
|
|
||||||
timesteps_spacing: TimestepSpacing = TimestepSpacing.LINSPACE_FLOAT,
|
|
||||||
timesteps_offset: int = 0,
|
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
|
||||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
|
params: SolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = float32,
|
dtype: DType = float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -73,32 +126,33 @@ class Solver(fl.Module, ABC):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps: The number of inference steps to perform.
|
num_inference_steps: The number of inference steps to perform.
|
||||||
num_train_timesteps: The number of timesteps used to train the diffusion process.
|
|
||||||
timesteps_spacing: The spacing to use for the timesteps.
|
|
||||||
timesteps_offset: The offset to use for the timesteps.
|
|
||||||
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
|
|
||||||
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
|
|
||||||
noise_schedule: The noise schedule used to sample the noise schedule.
|
|
||||||
first_inference_step: The first inference step to perform.
|
first_inference_step: The first inference step to perform.
|
||||||
|
params: The common parameters for solvers.
|
||||||
device: The PyTorch device to use for the solver's tensors.
|
device: The PyTorch device to use for the solver's tensors.
|
||||||
dtype: The PyTorch data type to use for the solver's tensors.
|
dtype: The PyTorch data type to use for the solver's tensors.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
self.num_train_timesteps = num_train_timesteps
|
|
||||||
self.timesteps_spacing = timesteps_spacing
|
|
||||||
self.timesteps_offset = timesteps_offset
|
|
||||||
self.initial_diffusion_rate = initial_diffusion_rate
|
|
||||||
self.final_diffusion_rate = final_diffusion_rate
|
|
||||||
self.noise_schedule = noise_schedule
|
|
||||||
self.first_inference_step = first_inference_step
|
self.first_inference_step = first_inference_step
|
||||||
|
self.params = self.resolve_params(params)
|
||||||
|
|
||||||
self.scale_factors = self.sample_noise_schedule()
|
self.scale_factors = self.sample_noise_schedule()
|
||||||
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
|
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
|
||||||
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
|
||||||
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
|
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
|
||||||
self.timesteps = self._generate_timesteps()
|
self.timesteps = self._generate_timesteps()
|
||||||
|
|
||||||
self.to(device=device, dtype=dtype)
|
self.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams:
|
||||||
|
if params is None:
|
||||||
|
return dataclasses.replace(self.default_params)
|
||||||
|
return dataclasses.replace(
|
||||||
|
self.default_params,
|
||||||
|
**{k: v for k, v in dataclasses.asdict(params).items() if v is not None},
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
"""Apply a step of the diffusion process using the Solver.
|
"""Apply a step of the diffusion process using the Solver.
|
||||||
|
@ -131,9 +185,9 @@ class Solver(fl.Module, ABC):
|
||||||
"""
|
"""
|
||||||
max_timestep = num_train_timesteps - 1 + offset
|
max_timestep = num_train_timesteps - 1 + offset
|
||||||
match spacing:
|
match spacing:
|
||||||
case TimestepSpacing.LINSPACE_FLOAT:
|
case TimestepSpacing.LINSPACE:
|
||||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
|
return tensor(np.linspace(offset, max_timestep, num_inference_steps), dtype=float32).flip(0)
|
||||||
case TimestepSpacing.LINSPACE_INT:
|
case TimestepSpacing.LINSPACE_ROUNDED:
|
||||||
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
|
return tensor(np.linspace(offset, max_timestep, num_inference_steps).round().astype(int)).flip(0)
|
||||||
case TimestepSpacing.LEADING:
|
case TimestepSpacing.LEADING:
|
||||||
step_ratio = num_train_timesteps // num_inference_steps
|
step_ratio = num_train_timesteps // num_inference_steps
|
||||||
|
@ -142,20 +196,15 @@ class Solver(fl.Module, ABC):
|
||||||
step_ratio = num_train_timesteps // num_inference_steps
|
step_ratio = num_train_timesteps // num_inference_steps
|
||||||
max_timestep = num_train_timesteps - 1 + offset
|
max_timestep = num_train_timesteps - 1 + offset
|
||||||
return arange(max_timestep, offset, -step_ratio)
|
return arange(max_timestep, offset, -step_ratio)
|
||||||
case TimestepSpacing.TRAILING_ALT:
|
case TimestepSpacing.CUSTOM:
|
||||||
# We use numpy here because:
|
raise RuntimeError("generate_timesteps called with custom spacing")
|
||||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
|
||||||
# torch.linspace(0,999,31)[15] is 499.5
|
|
||||||
# and we want the same result as the original DPM codebase.
|
|
||||||
np_space = np.linspace(offset, max_timestep, num_inference_steps + 1).round().astype(int)[1:]
|
|
||||||
return tensor(np_space).flip(0)
|
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
def _generate_timesteps(self) -> Tensor:
|
||||||
return self.generate_timesteps(
|
return self.generate_timesteps(
|
||||||
spacing=self.timesteps_spacing,
|
spacing=self.params.timesteps_spacing,
|
||||||
num_inference_steps=self.num_inference_steps,
|
num_inference_steps=self.num_inference_steps,
|
||||||
num_train_timesteps=self.num_train_timesteps,
|
num_train_timesteps=self.params.num_train_timesteps,
|
||||||
offset=self.timesteps_offset,
|
offset=self.params.timesteps_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_noise(
|
def add_noise(
|
||||||
|
@ -239,15 +288,10 @@ class Solver(fl.Module, ABC):
|
||||||
Returns:
|
Returns:
|
||||||
A new solver instance with the specified parameters.
|
A new solver instance with the specified parameters.
|
||||||
"""
|
"""
|
||||||
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
|
|
||||||
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
|
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=self.num_inference_steps if num_inference_steps is None else num_inference_steps,
|
||||||
num_train_timesteps=self.num_train_timesteps,
|
first_inference_step=self.first_inference_step if first_inference_step is None else first_inference_step,
|
||||||
initial_diffusion_rate=self.initial_diffusion_rate,
|
params=dataclasses.replace(self.params),
|
||||||
final_diffusion_rate=self.final_diffusion_rate,
|
|
||||||
noise_schedule=self.noise_schedule,
|
|
||||||
first_inference_step=first_inference_step,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
@ -282,9 +326,9 @@ class Solver(fl.Module, ABC):
|
||||||
"""
|
"""
|
||||||
return (
|
return (
|
||||||
linspace(
|
linspace(
|
||||||
start=self.initial_diffusion_rate ** (1 / power),
|
start=self.params.initial_diffusion_rate ** (1 / power),
|
||||||
end=self.final_diffusion_rate ** (1 / power),
|
end=self.params.final_diffusion_rate ** (1 / power),
|
||||||
steps=self.num_train_timesteps,
|
steps=self.params.num_train_timesteps,
|
||||||
)
|
)
|
||||||
** power
|
** power
|
||||||
)
|
)
|
||||||
|
@ -295,7 +339,7 @@ class Solver(fl.Module, ABC):
|
||||||
Returns:
|
Returns:
|
||||||
A tensor representing the noise schedule.
|
A tensor representing the noise schedule.
|
||||||
"""
|
"""
|
||||||
match self.noise_schedule:
|
match self.params.noise_schedule:
|
||||||
case "uniform":
|
case "uniform":
|
||||||
return 1 - self.sample_power_distribution(1)
|
return 1 - self.sample_power_distribution(1)
|
||||||
case "quadratic":
|
case "quadratic":
|
||||||
|
@ -303,7 +347,7 @@ class Solver(fl.Module, ABC):
|
||||||
case "karras":
|
case "karras":
|
||||||
return 1 - self.sample_power_distribution(7)
|
return 1 - self.sample_power_distribution(7)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
|
raise ValueError(f"Unknown noise schedule: {self.params.noise_schedule}")
|
||||||
|
|
||||||
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
|
||||||
"""Move the solver to the specified device and data type.
|
"""Move the solver to the specified device and data type.
|
||||||
|
|
|
@ -27,7 +27,7 @@ from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
||||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||||
from refiners.foundationals.latent_diffusion.restart import Restart
|
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule
|
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
|
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
|
||||||
|
@ -600,7 +600,7 @@ def sd15_ddim_karras(
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
|
|
||||||
ddim_solver = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
|
ddim_solver = DDIM(num_inference_steps=20, params=SolverParams(noise_schedule=NoiseSchedule.KARRAS))
|
||||||
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
sd15 = StableDiffusion_1(solver=ddim_solver, device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||||
|
|
|
@ -11,8 +11,10 @@ from refiners.foundationals.latent_diffusion.solvers import (
|
||||||
DPMSolver,
|
DPMSolver,
|
||||||
Euler,
|
Euler,
|
||||||
LCMSolver,
|
LCMSolver,
|
||||||
|
ModelPredictionType,
|
||||||
NoiseSchedule,
|
NoiseSchedule,
|
||||||
Solver,
|
Solver,
|
||||||
|
SolverParams,
|
||||||
TimestepSpacing,
|
TimestepSpacing,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,7 +43,10 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
||||||
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
final_sigmas_type="sigma_min", # default before Diffusers 0.26.0
|
||||||
)
|
)
|
||||||
diffusers_scheduler.set_timesteps(n_steps)
|
diffusers_scheduler.set_timesteps(n_steps)
|
||||||
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
|
refiners_scheduler = DPMSolver(
|
||||||
|
num_inference_steps=n_steps,
|
||||||
|
last_step_first_order=last_step_first_order,
|
||||||
|
)
|
||||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
sample = randn(1, 3, 32, 32)
|
sample = randn(1, 3, 32, 32)
|
||||||
|
@ -80,10 +85,12 @@ def test_ddim_diffusers():
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
def test_euler_diffusers():
|
@pytest.mark.parametrize("model_prediction_type", [ModelPredictionType.NOISE, ModelPredictionType.SAMPLE])
|
||||||
|
def test_euler_diffusers(model_prediction_type: ModelPredictionType):
|
||||||
from diffusers import EulerDiscreteScheduler # type: ignore
|
from diffusers import EulerDiscreteScheduler # type: ignore
|
||||||
|
|
||||||
manual_seed(0)
|
manual_seed(0)
|
||||||
|
diffusers_prediction_type = "epsilon" if model_prediction_type == ModelPredictionType.NOISE else "sample"
|
||||||
diffusers_scheduler = EulerDiscreteScheduler(
|
diffusers_scheduler = EulerDiscreteScheduler(
|
||||||
beta_end=0.012,
|
beta_end=0.012,
|
||||||
beta_schedule="scaled_linear",
|
beta_schedule="scaled_linear",
|
||||||
|
@ -92,9 +99,10 @@ def test_euler_diffusers():
|
||||||
steps_offset=1,
|
steps_offset=1,
|
||||||
timestep_spacing="linspace",
|
timestep_spacing="linspace",
|
||||||
use_karras_sigmas=False,
|
use_karras_sigmas=False,
|
||||||
|
prediction_type=diffusers_prediction_type,
|
||||||
)
|
)
|
||||||
diffusers_scheduler.set_timesteps(30)
|
diffusers_scheduler.set_timesteps(30)
|
||||||
refiners_scheduler = Euler(num_inference_steps=30)
|
refiners_scheduler = Euler(num_inference_steps=30, params=SolverParams(model_prediction_type=model_prediction_type))
|
||||||
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
assert equal(refiners_scheduler.timesteps, diffusers_scheduler.timesteps)
|
||||||
|
|
||||||
sample = randn(1, 4, 32, 32)
|
sample = randn(1, 4, 32, 32)
|
||||||
|
@ -192,16 +200,20 @@ def test_solver_device(test_device: Device):
|
||||||
|
|
||||||
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
||||||
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
||||||
scheduler = DDIM(num_inference_steps=30, device=test_device, noise_schedule=noise_schedule)
|
scheduler = DDIM(
|
||||||
|
num_inference_steps=30,
|
||||||
|
params=SolverParams(noise_schedule=noise_schedule),
|
||||||
|
device=test_device,
|
||||||
|
)
|
||||||
assert len(scheduler.scale_factors) == 1000
|
assert len(scheduler.scale_factors) == 1000
|
||||||
assert scheduler.scale_factors[0] == 1 - scheduler.initial_diffusion_rate
|
assert scheduler.scale_factors[0] == 1 - scheduler.params.initial_diffusion_rate
|
||||||
assert scheduler.scale_factors[-1] == 1 - scheduler.final_diffusion_rate
|
assert scheduler.scale_factors[-1] == 1 - scheduler.params.final_diffusion_rate
|
||||||
|
|
||||||
|
|
||||||
def test_solver_timestep_spacing():
|
def test_solver_timestep_spacing():
|
||||||
# Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
# Tests we get the results from [[arXiv:2305.08891] Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) table 2.
|
||||||
linspace_int = Solver.generate_timesteps(
|
linspace_int = Solver.generate_timesteps(
|
||||||
spacing=TimestepSpacing.LINSPACE_INT,
|
spacing=TimestepSpacing.LINSPACE_ROUNDED,
|
||||||
num_inference_steps=10,
|
num_inference_steps=10,
|
||||||
num_train_timesteps=1000,
|
num_train_timesteps=1000,
|
||||||
offset=1,
|
offset=1,
|
||||||
|
|
Loading…
Reference in a new issue