From 7309a0985e4e014e26e21b94d64b4f8383ee2c7e Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 2 Feb 2024 10:54:59 +0000 Subject: [PATCH] (doc/fluxion/ld) add `Solver` docstrings --- .../latent_diffusion/solvers/solver.py | 117 ++++++++++++++++-- 1 file changed, 104 insertions(+), 13 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/solvers/solver.py b/src/refiners/foundationals/latent_diffusion/solvers/solver.py index 6579fa1..0bdc359 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/solver.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/solver.py @@ -10,16 +10,23 @@ T = TypeVar("T", bound="Solver") class NoiseSchedule(str, Enum): + """An enumeration of noise schedules used to sample the noise schedule. + + Attributes: + UNIFORM: A uniform noise schedule. + QUADRATIC: A quadratic noise schedule. + KARRAS: See [[arXiv:2206.00364] Elucidating the Design Space of Diffusion-Based Generative Models, Equation 5](https://arxiv.org/abs/2206.00364) + """ + UNIFORM = "uniform" QUADRATIC = "quadratic" KARRAS = "karras" class Solver(fl.Module, ABC): - """ - A base class for creating a diffusion model solver. + """The base class for creating a diffusion model solver. - Solver creates a sequence of noise and scaling factors used in the diffusion process, + Solvers create a sequence of noise and scaling factors used in the diffusion process, which gradually transforms the original data distribution into a Gaussian one. This process is described using several parameters such as initial and final diffusion rates, @@ -39,6 +46,18 @@ class Solver(fl.Module, ABC): device: Device | str = "cpu", dtype: DType = float32, ) -> None: + """Initializes a new `Solver` instance. + + Args: + num_inference_steps: The number of inference steps to perform. + num_train_timesteps: The number of timesteps used to train the diffusion process. + initial_diffusion_rate: The initial diffusion rate used to sample the noise schedule. + final_diffusion_rate: The final diffusion rate used to sample the noise schedule. + noise_schedule: The noise schedule used to sample the noise schedule. + first_inference_step: The first inference step to perform. + device: The PyTorch device to use for the solver's tensors. + dtype: The PyTorch data type to use for the solver's tensors. + """ super().__init__() self.num_inference_steps = num_inference_steps self.num_train_timesteps = num_train_timesteps @@ -55,19 +74,25 @@ class Solver(fl.Module, ABC): @abstractmethod def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: - """ - Applies a step of the diffusion process to the input tensor `x` using the provided `predicted_noise` and `timestep`. + """Apply a step of the diffusion process using the Solver. - This method should be overridden by subclasses to implement the specific diffusion process. + Note: + This method should be overridden by subclasses to implement the specific diffusion process. + + Args: + x: The input tensor to apply the diffusion process to. + predicted_noise: The predicted noise tensor for the current step. + step: The current step of the diffusion process. + generator: The random number generator to use for sampling noise. """ ... @abstractmethod def _generate_timesteps(self) -> Tensor: - """ - Generates a tensor of timesteps. + """Generate a tensor of timesteps. - This method should be overridden by subclasses to provide the specific timesteps for the diffusion process. + Note: + This method should be overridden by subclasses to provide the specific timesteps for the diffusion process. """ ... @@ -77,6 +102,16 @@ class Solver(fl.Module, ABC): noise: Tensor, step: int, ) -> Tensor: + """Add noise to the input tensor using the solver's parameters. + + Args: + x: The input tensor to add noise to. + noise: The noise tensor to add to the input tensor. + step: The current step of the diffusion process. + + Returns: + The input tensor with added noise. + """ timestep = self.timesteps[step] cumulative_scale_factors = self.cumulative_scale_factors[timestep] noise_stds = self.noise_std[timestep] @@ -84,28 +119,43 @@ class Solver(fl.Module, ABC): return noised_x def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + """Remove noise from the input tensor using the current step of the diffusion process. + + Args: + x: The input tensor to remove noise from. + noise: The noise tensor to remove from the input tensor. + step: The current step of the diffusion process. + + Returns: + The denoised input tensor. + """ timestep = self.timesteps[step] cumulative_scale_factors = self.cumulative_scale_factors[timestep] noise_stds = self.noise_std[timestep] - # See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like - # in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance) + # See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. + # Useful to preview progress or for guidance + # See also https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance) denoised_x = (x - noise_stds * noise) / cumulative_scale_factors return denoised_x @property def all_steps(self) -> list[int]: + """Return a list of all inference steps.""" return list(range(self.num_inference_steps)) @property def inference_steps(self) -> list[int]: + """Return a list of inference steps to perform.""" return self.all_steps[self.first_inference_step :] @property def device(self) -> Device: + """The PyTorch device used for the solver's tensors.""" return self.scale_factors.device @property def dtype(self) -> DType: + """The PyTorch data type used for the solver's tensors.""" return self.scale_factors.dtype @device.setter @@ -117,6 +167,15 @@ class Solver(fl.Module, ABC): self.to(dtype=dtype) def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T: + """Rebuild the solver with new parameters. + + Args: + num_inference_steps: The number of inference steps to perform. + first_inference_step: The first inference step to perform. + + Returns: + A new solver instance with the specified parameters. + """ num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step return self.__class__( @@ -131,12 +190,30 @@ class Solver(fl.Module, ABC): ) def scale_model_input(self, x: Tensor, step: int) -> Tensor: - """ - For compatibility with solvers that need to scale the input according to the current timestep. + """Scale the model's input according to the current timestep. + + Note: + This method should only be overridden by solvers that + need to scale the input according to the current timestep. + + Args: + x: The input tensor to scale. + step: The current step of the diffusion process. + + Returns: + The scaled input tensor. """ return x def sample_power_distribution(self, power: float = 2, /) -> Tensor: + """Sample a power distribution. + + Args: + power: The power to use for the distribution. + + Returns: + A tensor representing the power distribution between the initial and final diffusion rates of the solver. + """ return ( linspace( start=self.initial_diffusion_rate ** (1 / power), @@ -147,6 +224,11 @@ class Solver(fl.Module, ABC): ) def sample_noise_schedule(self) -> Tensor: + """Sample the noise schedule. + + Returns: + A tensor representing the noise schedule. + """ match self.noise_schedule: case "uniform": return 1 - self.sample_power_distribution(1) @@ -158,6 +240,15 @@ class Solver(fl.Module, ABC): raise ValueError(f"Unknown noise schedule: {self.noise_schedule}") def to(self, device: Device | str | None = None, dtype: DType | None = None) -> "Solver": + """Move the solver to the specified device and data type. + + Args: + device: The PyTorch device to move the solver to. + dtype: The PyTorch data type to move the solver to. + + Returns: + The solver instance, moved to the specified device and data type. + """ super().to(device=device, dtype=dtype) for name, attr in [(name, attr) for name, attr in self.__dict__.items() if isinstance(attr, Tensor)]: match name: