mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
(doc/fluxion/ld) add DDPM
, DDIM
, DPM++
and Euleur
docstrings
This commit is contained in:
parent
24c11745cd
commit
ff17991261
|
@ -4,6 +4,11 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
|
|||
|
||||
|
||||
class DDIM(Solver):
|
||||
"""Denoising Diffusion Implicit Model (DDIM) solver.
|
||||
|
||||
See [[arXiv:2010.02502] Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
|
@ -15,6 +20,18 @@ class DDIM(Solver):
|
|||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
) -> None:
|
||||
"""Initializes a new DDIM solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
noise_schedule: The noise schedule.
|
||||
first_inference_step: The first inference step.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
from torch import Generator, Tensor, arange, device as Device, dtype as DType
|
||||
from torch import Generator, Tensor, arange, device as Device
|
||||
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver
|
||||
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||
|
||||
|
||||
class DDPM(Solver):
|
||||
"""
|
||||
Denoising Diffusion Probabilistic Model
|
||||
"""Denoising Diffusion Probabilistic Model (DDPM) solver.
|
||||
|
||||
Only used for training Latent Diffusion models. Cannot be called.
|
||||
Warning:
|
||||
Only used for training Latent Diffusion models.
|
||||
Cannot be called.
|
||||
|
||||
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -16,11 +19,19 @@ class DDPM(Solver):
|
|||
num_train_timesteps: int = 1_000,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule | None = None, # ignored
|
||||
first_inference_step: int = 0,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType | None = None, # ignored
|
||||
) -> None:
|
||||
"""Initializes a new DDPM solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
first_inference_step: The first inference step.
|
||||
device: The PyTorch device to use.
|
||||
"""
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
|
|
|
@ -7,9 +7,12 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
|
|||
|
||||
|
||||
class DPMSolver(Solver):
|
||||
"""
|
||||
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
||||
"""Diffusion probabilistic models (DPMs) solver.
|
||||
|
||||
See [[arXiv:2211.01095] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://arxiv.org/abs/2211.01095)
|
||||
for more details.
|
||||
|
||||
Note:
|
||||
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
|
||||
when used with SDXL and few steps. This parameter is a way to mitigate that
|
||||
effect by using a first-order (Euler) update instead of a second-order update
|
||||
|
@ -42,10 +45,16 @@ class DPMSolver(Solver):
|
|||
self.last_step_first_order = last_step_first_order
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
# We need to use numpy here because:
|
||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
# torch.linspace(0,999,31)[15] is 499.5
|
||||
# ...and we want the same result as the original codebase.
|
||||
"""Generate the timesteps used by the solver.
|
||||
|
||||
Note:
|
||||
We need to use numpy here because:
|
||||
|
||||
- numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
- torch.linspace(0,999,31)[15] is 499.5
|
||||
|
||||
and we want the same result as the original codebase.
|
||||
"""
|
||||
return tensor(
|
||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
|
||||
).flip(0)
|
||||
|
@ -55,6 +64,12 @@ class DPMSolver(Solver):
|
|||
num_inference_steps: int | None,
|
||||
first_inference_step: int | None = None,
|
||||
) -> "DPMSolver":
|
||||
"""Rebuilds the solver with new parameters.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
first_inference_step: The first inference step.
|
||||
"""
|
||||
r = super().rebuild(
|
||||
num_inference_steps=num_inference_steps,
|
||||
first_inference_step=first_inference_step,
|
||||
|
@ -63,6 +78,16 @@ class DPMSolver(Solver):
|
|||
return r
|
||||
|
||||
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
"""Applies a first-order backward Euler update to the input data `x`.
|
||||
|
||||
Args:
|
||||
x: The input data.
|
||||
noise: The predicted noise.
|
||||
step: The current step.
|
||||
|
||||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
"""
|
||||
current_timestep = self.timesteps[step]
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||
|
||||
|
@ -79,6 +104,15 @@ class DPMSolver(Solver):
|
|||
return denoised_x
|
||||
|
||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
||||
"""Applies a second-order backward Euler update to the input data `x`.
|
||||
|
||||
Args:
|
||||
x: The input data.
|
||||
step: The current step.
|
||||
|
||||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
"""
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||
current_timestep = self.timesteps[step]
|
||||
next_timestep = self.timesteps[step - 1]
|
||||
|
@ -106,12 +140,21 @@ class DPMSolver(Solver):
|
|||
return denoised_x
|
||||
|
||||
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
"""
|
||||
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
|
||||
"""Apply one step of the backward diffusion process.
|
||||
|
||||
Note:
|
||||
This method works by estimating the denoised version of `x` and applying either a first-order or second-order
|
||||
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
|
||||
(ODEs).
|
||||
|
||||
Args:
|
||||
x: The input data.
|
||||
predicted_noise: The predicted noise.
|
||||
step: The current step.
|
||||
generator: The random number generator.
|
||||
|
||||
Returns:
|
||||
The denoised version of the input data `x`.
|
||||
"""
|
||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||
|
||||
|
|
|
@ -6,6 +6,12 @@ from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule
|
|||
|
||||
|
||||
class Euler(Solver):
|
||||
"""Euler solver.
|
||||
|
||||
See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364)
|
||||
for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
|
@ -17,6 +23,18 @@ class Euler(Solver):
|
|||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
"""Initializes a new Euler solver.
|
||||
|
||||
Args:
|
||||
num_inference_steps: The number of inference steps.
|
||||
num_train_timesteps: The number of training timesteps.
|
||||
initial_diffusion_rate: The initial diffusion rate.
|
||||
final_diffusion_rate: The final diffusion rate.
|
||||
noise_schedule: The noise schedule.
|
||||
first_inference_step: The first inference step.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
if noise_schedule != NoiseSchedule.QUADRATIC:
|
||||
raise NotImplementedError
|
||||
super().__init__(
|
||||
|
@ -33,23 +51,40 @@ class Euler(Solver):
|
|||
|
||||
@property
|
||||
def init_noise_sigma(self) -> Tensor:
|
||||
"""The initial noise sigma."""
|
||||
return self.sigmas.max()
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
# We need to use numpy here because:
|
||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
# torch.linspace(0,999,31)[15] is 499.5
|
||||
# ...and we want the same result as the original codebase.
|
||||
"""Generate the timesteps used by the solver.
|
||||
|
||||
Note:
|
||||
We need to use numpy here because:
|
||||
|
||||
- numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
- torch.linspace(0,999,31)[15] is 499.5
|
||||
|
||||
and we want the same result as the original codebase.
|
||||
"""
|
||||
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
|
||||
return timesteps
|
||||
|
||||
def _generate_sigmas(self) -> Tensor:
|
||||
"""Generate the sigmas used by the solver."""
|
||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
|
||||
sigmas = torch.cat([sigmas, tensor([0.0])])
|
||||
return sigmas.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
"""Scales the model input according to the current step.
|
||||
|
||||
Args:
|
||||
x: The model input.
|
||||
step: The current step.
|
||||
|
||||
Returns:
|
||||
The scaled model input.
|
||||
"""
|
||||
sigma = self.sigmas[step]
|
||||
return x / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
|
|
Loading…
Reference in a new issue