mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
(doc/fluxion/ld) add DDPM
, DDIM
, DPM++
and Euleur
docstrings
This commit is contained in:
parent
6d8016190c
commit
fc7b4dd62d
|
@ -4,6 +4,11 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
|
||||||
|
|
||||||
|
|
||||||
class DDIM(Solver):
|
class DDIM(Solver):
|
||||||
|
"""Denoising Diffusion Implicit Model (DDIM) solver.
|
||||||
|
|
||||||
|
See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
|
@ -15,6 +20,18 @@ class DDIM(Solver):
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initializes a new DDIM solver.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps: The number of inference steps.
|
||||||
|
num_train_timesteps: The number of training timesteps.
|
||||||
|
initial_diffusion_rate: The initial diffusion rate.
|
||||||
|
final_diffusion_rate: The final diffusion rate.
|
||||||
|
noise_schedule: The noise schedule.
|
||||||
|
first_inference_step: The first inference step.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
from torch import Generator, Tensor, arange, device as Device, dtype as DType
|
from torch import Generator, Tensor, arange, device as Device
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||||
|
|
||||||
|
|
||||||
class DDPM(Solver):
|
class DDPM(Solver):
|
||||||
"""
|
"""Denoising Diffusion Probabilistic Model (DDPM) solver.
|
||||||
Denoising Diffusion Probabilistic Model
|
|
||||||
|
|
||||||
Only used for training Latent Diffusion models. Cannot be called.
|
Warning:
|
||||||
|
Only used for training Latent Diffusion models.
|
||||||
|
Cannot be called.
|
||||||
|
|
||||||
|
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -16,11 +19,19 @@ class DDPM(Solver):
|
||||||
num_train_timesteps: int = 1_000,
|
num_train_timesteps: int = 1_000,
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
noise_schedule: NoiseSchedule | None = None, # ignored
|
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType | None = None, # ignored
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initializes a new DDPM solver.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps: The number of inference steps.
|
||||||
|
num_train_timesteps: The number of training timesteps.
|
||||||
|
initial_diffusion_rate: The initial diffusion rate.
|
||||||
|
final_diffusion_rate: The final diffusion rate.
|
||||||
|
first_inference_step: The first inference step.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
num_train_timesteps=num_train_timesteps,
|
num_train_timesteps=num_train_timesteps,
|
||||||
|
|
|
@ -7,13 +7,16 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
|
||||||
|
|
||||||
|
|
||||||
class DPMSolver(Solver):
|
class DPMSolver(Solver):
|
||||||
"""
|
"""Diffusion probabilistic models (DPMs) solver.
|
||||||
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
|
||||||
|
|
||||||
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
|
See [[arXiv:2211.01095] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://arxiv.org/abs/2211.01095)
|
||||||
when used with SDXL and few steps. This parameter is a way to mitigate that
|
for more details.
|
||||||
effect by using a first-order (Euler) update instead of a second-order update
|
|
||||||
for the last step of the diffusion.
|
Note:
|
||||||
|
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
|
||||||
|
when used with SDXL and few steps. This parameter is a way to mitigate that
|
||||||
|
effect by using a first-order (Euler) update instead of a second-order update
|
||||||
|
for the last step of the diffusion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -42,10 +45,16 @@ class DPMSolver(Solver):
|
||||||
self.last_step_first_order = last_step_first_order
|
self.last_step_first_order = last_step_first_order
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
def _generate_timesteps(self) -> Tensor:
|
||||||
# We need to use numpy here because:
|
"""Generate the timesteps used by the solver.
|
||||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
|
||||||
# torch.linspace(0,999,31)[15] is 499.5
|
Note:
|
||||||
# ...and we want the same result as the original codebase.
|
We need to use numpy here because:
|
||||||
|
|
||||||
|
- numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||||
|
- torch.linspace(0,999,31)[15] is 499.5
|
||||||
|
|
||||||
|
and we want the same result as the original codebase.
|
||||||
|
"""
|
||||||
return tensor(
|
return tensor(
|
||||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
|
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
|
||||||
).flip(0)
|
).flip(0)
|
||||||
|
@ -55,6 +64,12 @@ class DPMSolver(Solver):
|
||||||
num_inference_steps: int | None,
|
num_inference_steps: int | None,
|
||||||
first_inference_step: int | None = None,
|
first_inference_step: int | None = None,
|
||||||
) -> "DPMSolver":
|
) -> "DPMSolver":
|
||||||
|
"""Rebuilds the solver with new parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps: The number of inference steps.
|
||||||
|
first_inference_step: The first inference step.
|
||||||
|
"""
|
||||||
r = super().rebuild(
|
r = super().rebuild(
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
first_inference_step=first_inference_step,
|
first_inference_step=first_inference_step,
|
||||||
|
@ -63,6 +78,16 @@ class DPMSolver(Solver):
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
|
"""Applies a first-order backward Euler update to the input data `x`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input data.
|
||||||
|
noise: The predicted noise.
|
||||||
|
step: The current step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The denoised version of the input data `x`.
|
||||||
|
"""
|
||||||
current_timestep = self.timesteps[step]
|
current_timestep = self.timesteps[step]
|
||||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||||
|
|
||||||
|
@ -79,6 +104,15 @@ class DPMSolver(Solver):
|
||||||
return denoised_x
|
return denoised_x
|
||||||
|
|
||||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
||||||
|
"""Applies a second-order backward Euler update to the input data `x`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input data.
|
||||||
|
step: The current step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The denoised version of the input data `x`.
|
||||||
|
"""
|
||||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||||
current_timestep = self.timesteps[step]
|
current_timestep = self.timesteps[step]
|
||||||
next_timestep = self.timesteps[step - 1]
|
next_timestep = self.timesteps[step - 1]
|
||||||
|
@ -106,12 +140,21 @@ class DPMSolver(Solver):
|
||||||
return denoised_x
|
return denoised_x
|
||||||
|
|
||||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
"""
|
"""Apply one step of the backward diffusion process.
|
||||||
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
|
|
||||||
|
|
||||||
This method works by estimating the denoised version of `x` and applying either a first-order or second-order
|
Note:
|
||||||
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
|
This method works by estimating the denoised version of `x` and applying either a first-order or second-order
|
||||||
(ODEs).
|
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
|
||||||
|
(ODEs).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input data.
|
||||||
|
predicted_noise: The predicted noise.
|
||||||
|
step: The current step.
|
||||||
|
generator: The random number generator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The denoised version of the input data `x`.
|
||||||
"""
|
"""
|
||||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,12 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
|
||||||
|
|
||||||
|
|
||||||
class Euler(Solver):
|
class Euler(Solver):
|
||||||
|
"""Euler solver.
|
||||||
|
|
||||||
|
See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364)
|
||||||
|
for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
|
@ -17,6 +23,18 @@ class Euler(Solver):
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
):
|
):
|
||||||
|
"""Initializes a new Euler solver.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_inference_steps: The number of inference steps.
|
||||||
|
num_train_timesteps: The number of training timesteps.
|
||||||
|
initial_diffusion_rate: The initial diffusion rate.
|
||||||
|
final_diffusion_rate: The final diffusion rate.
|
||||||
|
noise_schedule: The noise schedule.
|
||||||
|
first_inference_step: The first inference step.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
if noise_schedule != NoiseSchedule.QUADRATIC:
|
if noise_schedule != NoiseSchedule.QUADRATIC:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -33,23 +51,40 @@ class Euler(Solver):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def init_noise_sigma(self) -> Tensor:
|
def init_noise_sigma(self) -> Tensor:
|
||||||
|
"""The initial noise sigma."""
|
||||||
return self.sigmas.max()
|
return self.sigmas.max()
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
def _generate_timesteps(self) -> Tensor:
|
||||||
# We need to use numpy here because:
|
"""Generate the timesteps used by the solver.
|
||||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
|
||||||
# torch.linspace(0,999,31)[15] is 499.5
|
Note:
|
||||||
# ...and we want the same result as the original codebase.
|
We need to use numpy here because:
|
||||||
|
|
||||||
|
- numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||||
|
- torch.linspace(0,999,31)[15] is 499.5
|
||||||
|
|
||||||
|
and we want the same result as the original codebase.
|
||||||
|
"""
|
||||||
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
|
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
|
||||||
return timesteps
|
return timesteps
|
||||||
|
|
||||||
def _generate_sigmas(self) -> Tensor:
|
def _generate_sigmas(self) -> Tensor:
|
||||||
|
"""Generate the sigmas used by the solver."""
|
||||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||||
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
|
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
|
||||||
sigmas = torch.cat([sigmas, tensor([0.0])])
|
sigmas = torch.cat([sigmas, tensor([0.0])])
|
||||||
return sigmas.to(device=self.device, dtype=self.dtype)
|
return sigmas.to(device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||||
|
"""Scales the model input according to the current step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The model input.
|
||||||
|
step: The current step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The scaled model input.
|
||||||
|
"""
|
||||||
sigma = self.sigmas[step]
|
sigma = self.sigmas[step]
|
||||||
return x / ((sigma**2 + 1) ** 0.5)
|
return x / ((sigma**2 + 1) ** 0.5)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue