(doc/fluxion/ld) add Solver docstrings

This commit is contained in:
Laurent 2024-02-02 10:54:59 +00:00 committed by Laureηt
parent 289261f2fb
commit 0c5a7a8269

View file

@ -10,16 +10,23 @@ T = TypeVar("T", bound="Solver")
class NoiseSchedule(str, Enum):
"""An enumeration of noise schedules used to sample the noise schedule.
Attributes:
UNIFORM: A uniform noise schedule.
QUADRATIC: A quadratic noise schedule.
KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364)
"""
UNIFORM = "uniform"
QUADRATIC = "quadratic"
KARRAS = "karras"
class Solver(fl.Module, ABC):
"""
A base class for creating a diffusion model solver.
"""The base class for creating a diffusion model solver.
Solver creates a sequence of noise and scaling factors used in the diffusion process,
Solvers create a sequence of noise and scaling factors used in the diffusion process,
which gradually transforms the original data distribution into a Gaussian one.
This process is described using several parameters such as initial and final diffusion rates,
@ -39,6 +46,18 @@ class Solver(fl.Module, ABC):
device: Device | str = "cpu",
dtype: DType = float32,
) -> None:
"""Initializes a new `Solver` instance.
Args:
num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of timesteps used to train the diffusion process.
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
noise_schedule: The noise schedule used to sample the noise schedule.
first_inference_step: The first inference step to perform.
device: The PyTorch device to use for the solver's tensors.
dtype: The PyTorch data type to use for the solver's tensors.
"""
super().__init__()
self.num_inference_steps = num_inference_steps
self.num_train_timesteps = num_train_timesteps
@ -55,19 +74,25 @@ class Solver(fl.Module, ABC):
@abstractmethod
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `predicted_noise` and `timestep`.
"""Apply a step of the diffusion process using the Solver.
This method should be overridden by subclasses to implement the specific diffusion process.
Note:
This method should be overridden by subclasses to implement the specific diffusion process.
Args:
x: The input tensor to apply the diffusion process to.
predicted_noise: The predicted noise tensor for the current step.
step: The current step of the diffusion process.
generator: The random number generator to use for sampling noise.
"""
...
@abstractmethod
def _generate_timesteps(self) -> Tensor:
"""
Generates a tensor of timesteps.
"""Generate a tensor of timesteps.
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
Note:
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
"""
...
@ -77,6 +102,16 @@ class Solver(fl.Module, ABC):
noise: Tensor,
step: int,
) -> Tensor:
"""Add noise to the input tensor using the solver's parameters.
Args:
x: The input tensor to add noise to.
noise: The noise tensor to add to the input tensor.
step: The current step of the diffusion process.
Returns:
The input tensor with added noise.
"""
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
@ -84,28 +119,43 @@ class Solver(fl.Module, ABC):
return noised_x
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
"""Remove noise from the input tensor using the current step of the diffusion process.
Args:
x: The input tensor to remove noise from.
noise: The noise tensor to remove from the input tensor.
step: The current step of the diffusion process.
Returns:
The denoised input tensor.
"""
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like
# in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf.
# Useful to preview progress or for guidance
# See also https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
return denoised_x
@property
def all_steps(self) -> list[int]:
"""Return a list of all inference steps."""
return list(range(self.num_inference_steps))
@property
def inference_steps(self) -> list[int]:
"""Return a list of inference steps to perform."""
return self.all_steps[self.first_inference_step :]
@property
def device(self) -> Device:
"""The PyTorch device used for the solver's tensors."""
return self.scale_factors.device
@property
def dtype(self) -> DType:
"""The PyTorch data type used for the solver's tensors."""
return self.scale_factors.dtype
@device.setter
@ -117,6 +167,15 @@ class Solver(fl.Module, ABC):
self.to(dtype=dtype)
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
"""Rebuild the solver with new parameters.
Args:
num_inference_steps: The number of inference steps to perform.
first_inference_step: The first inference step to perform.
Returns:
A new solver instance with the specified parameters.
"""
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
return self.__class__(
@ -131,12 +190,30 @@ class Solver(fl.Module, ABC):
)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""
For compatibility with solvers that need to scale the input according to the current timestep.
"""Scale the model's input according to the current timestep.
Note:
This method should only be overridden by solvers that
need to scale the input according to the current timestep.
Args:
x: The input tensor to scale.
step: The current step of the diffusion process.
Returns:
The scaled input tensor.
"""
return x
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
"""Sample a power distribution.
Args:
power: The power to use for the distribution.
Returns:
A tensor representing the power distribution between the initial and final diffusion rates of the solver.
"""
return (
linspace(
start=self.initial_diffusion_rate ** (1 / power),
@ -147,6 +224,11 @@ class Solver(fl.Module, ABC):
)
def sample_noise_schedule(self) -> Tensor:
"""Sample the noise schedule.
Returns:
A tensor representing the noise schedule.
"""
match self.noise_schedule:
case "uniform":
return 1 - self.sample_power_distribution(1)
@ -158,6 +240,15 @@ class Solver(fl.Module, ABC):
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
"""Move the solver to the specified device and data type.
Args:
device: The PyTorch device to move the solver to.
dtype: The PyTorch data type to move the solver to.
Returns:
The solver instance, moved to the specified device and data type.
"""
super().to(device=device, dtype=dtype)
for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:
match name: