(doc/fluxion/ld) add Solver docstrings

This commit is contained in:
Laurent 2024-02-02 10:54:59 +00:00 committed by Laureηt
parent 289261f2fb
commit 0c5a7a8269

View file

@ -10,16 +10,23 @@ T = TypeVar("T", bound="Solver")
class NoiseSchedule(str, Enum):
    """The kinds of noise schedule a solver can sample.

    Members are also `str` instances, so they compare equal to their plain string values.

    Attributes:
        UNIFORM: A uniform noise schedule.
        QUADRATIC: A quadratic noise schedule.
        KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364)
    """

    UNIFORM = "uniform"
    QUADRATIC = "quadratic"
    KARRAS = "karras"
class Solver(fl.Module, ABC): class Solver(fl.Module, ABC):
""" """The base class for creating a diffusion model solver.
A base class for creating a diffusion model solver.
Solver creates a sequence of noise and scaling factors used in the diffusion process, Solvers create a sequence of noise and scaling factors used in the diffusion process,
which gradually transforms the original data distribution into a Gaussian one. which gradually transforms the original data distribution into a Gaussian one.
This process is described using several parameters such as initial and final diffusion rates, This process is described using several parameters such as initial and final diffusion rates,
@ -39,6 +46,18 @@ class Solver(fl.Module, ABC):
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: DType = float32, dtype: DType = float32,
) -> None: ) -> None:
"""Initializes a new `Solver` instance.
Args:
num_inference_steps: The number of inference steps to perform.
num_train_timesteps: The number of timesteps used to train the diffusion process.
initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule.
final_diffusion_rate: The final diffusion rate used to sample the noise schedule.
noise_schedule: The kind of noise schedule to sample (see `NoiseSchedule`).
first_inference_step: The first inference step to perform.
device: The PyTorch device to use for the solver's tensors.
dtype: The PyTorch data type to use for the solver's tensors.
"""
super().__init__() super().__init__()
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.num_train_timesteps = num_train_timesteps self.num_train_timesteps = num_train_timesteps
@ -55,18 +74,24 @@ class Solver(fl.Module, ABC):
@abstractmethod @abstractmethod
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
""" """Apply a step of the diffusion process using the Solver.
Applies a step of the diffusion process to the input tensor `x` using the provided `predicted_noise` and `timestep`.
Note:
This method should be overridden by subclasses to implement the specific diffusion process. This method should be overridden by subclasses to implement the specific diffusion process.
Args:
x: The input tensor to apply the diffusion process to.
predicted_noise: The predicted noise tensor for the current step.
step: The current step of the diffusion process.
generator: The random number generator to use for sampling noise.
""" """
... ...
@abstractmethod
def _generate_timesteps(self) -> Tensor:
    """Generate a tensor of timesteps.

    Note:
        This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.

    Returns:
        The timesteps, as a tensor; their layout is defined by the subclass.
    """
    ...
@ -77,6 +102,16 @@ class Solver(fl.Module, ABC):
noise: Tensor, noise: Tensor,
step: int, step: int,
) -> Tensor: ) -> Tensor:
"""Add noise to the input tensor using the solver's parameters.
Args:
x: The input tensor to add noise to.
noise: The noise tensor to add to the input tensor.
step: The current step of the diffusion process.
Returns:
The input tensor with added noise.
"""
timestep = self.timesteps[step] timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep] cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep] noise_stds = self.noise_std[timestep]
@ -84,28 +119,43 @@ class Solver(fl.Module, ABC):
return noised_x return noised_x
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
    """Remove noise from the input tensor using the current step of the diffusion process.

    Computes `(x - noise_std[t] * noise) / cumulative_scale_factors[t]`, where `t` is
    the timestep looked up from `self.timesteps[step]`.

    Args:
        x: The input tensor to remove noise from.
        noise: The noise tensor to remove from the input tensor.
        step: The current step of the diffusion process.

    Returns:
        The denoised input tensor.
    """
    timestep = self.timesteps[step]
    cumulative_scale_factors = self.cumulative_scale_factors[timestep]
    noise_stds = self.noise_std[timestep]
    # See equation (15) from https://arxiv.org/pdf/2006.11239.pdf.
    # Useful to preview progress or for guidance
    # See also https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
    denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
    return denoised_x
@property
def all_steps(self) -> list[int]:
    """Return a list of all inference steps, from 0 up to `num_inference_steps - 1`."""
    return list(range(self.num_inference_steps))
@property
def inference_steps(self) -> list[int]:
    """Return the inference steps to perform: `all_steps` from `first_inference_step` onwards."""
    return self.all_steps[self.first_inference_step :]
@property
def device(self) -> Device:
    """The PyTorch device used for the solver's tensors, read from `scale_factors`."""
    return self.scale_factors.device
@property
def dtype(self) -> DType:
    """The PyTorch data type used for the solver's tensors, read from `scale_factors`."""
    return self.scale_factors.dtype
@device.setter @device.setter
@ -117,6 +167,15 @@ class Solver(fl.Module, ABC):
self.to(dtype=dtype) self.to(dtype=dtype)
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T: def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
"""Rebuild the solver with new parameters.
Args:
num_inference_steps: The number of inference steps to perform.
first_inference_step: The first inference step to perform.
Returns:
A new solver instance with the specified parameters.
"""
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
return self.__class__( return self.__class__(
@ -131,12 +190,30 @@ class Solver(fl.Module, ABC):
) )
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
    """Scale the model's input according to the current timestep.

    This base implementation returns `x` unchanged.

    Note:
        This method should only be overridden by solvers that
        need to scale the input according to the current timestep.

    Args:
        x: The input tensor to scale.
        step: The current step of the diffusion process.

    Returns:
        The scaled input tensor.
    """
    return x
def sample_power_distribution(self, power: float = 2, /) -> Tensor: def sample_power_distribution(self, power: float = 2, /) -> Tensor:
"""Sample a power distribution.
Args:
power: The power to use for the distribution.
Returns:
A tensor representing the power distribution between the initial and final diffusion rates of the solver.
"""
return ( return (
linspace( linspace(
start=self.initial_diffusion_rate ** (1 / power), start=self.initial_diffusion_rate ** (1 / power),
@ -147,6 +224,11 @@ class Solver(fl.Module, ABC):
) )
def sample_noise_schedule(self) -> Tensor: def sample_noise_schedule(self) -> Tensor:
"""Sample the noise schedule.
Returns:
A tensor representing the noise schedule.
"""
match self.noise_schedule: match self.noise_schedule:
case "uniform": case "uniform":
return 1 - self.sample_power_distribution(1) return 1 - self.sample_power_distribution(1)
@ -158,6 +240,15 @@ class Solver(fl.Module, ABC):
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}") raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver": def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver":
"""Move the solver to the specified device and data type.
Args:
device: The PyTorch device to move the solver to.
dtype: The PyTorch data type to move the solver to.
Returns:
The solver instance, moved to the specified device and data type.
"""
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype)
for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]: for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]:
match name: match name: