diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 4bb51e8..11d8450 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -33,6 +33,15 @@ class LatentDiffusionModel(fl.Module, ABC): self.classifier_free_guidance = classifier_free_guidance def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None: + """Set the steps of the diffusion process. + + Args: + num_steps: The number of inference steps. + first_step: The first inference step, used for image-to-image diffusion. + You may be used to setting a float in `[0, 1]` called `strength` instead, + which is an abstraction for this. The first step is + `round((1 - strength) * (num_steps - 1))`. + """ self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step) @staticmethod