diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py
index 44ec1cc..c92e125 100644
--- a/src/refiners/foundationals/latent_diffusion/model.py
+++ b/src/refiners/foundationals/latent_diffusion/model.py
@@ -33,15 +33,8 @@ class LatentDiffusionModel(fl.Module, ABC):
         self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
 
     def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
-        initial_diffusion_rate = self.scheduler.initial_diffusion_rate
-        final_diffusion_rate = self.scheduler.final_diffusion_rate
-        device, dtype = self.scheduler.device, self.scheduler.dtype
-        self.scheduler = self.scheduler.__class__(
-            num_inference_steps=num_steps,
-            initial_diffusion_rate=initial_diffusion_rate,
-            final_diffusion_rate=final_diffusion_rate,
-            first_inference_step=first_step,
-        ).to(device=device, dtype=dtype)
+        """Replace `self.scheduler` with a copy rebuilt for `num_steps` steps, starting at `first_step`."""
+        self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
 
     def init_latents(
         self,
diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
index 0711fd9..5a9c93c 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
@@ -51,6 +51,22 @@ class DPMSolver(Scheduler):
             device=self.device,
         ).flip(0)
 
+    def rebuild(
+        self: "DPMSolver",
+        num_inference_steps: int | None = None,
+        first_inference_step: int | None = None,
+    ) -> "DPMSolver":
+        """
+        Rebuild through the parent class, then copy `last_step_first_order` onto the new instance
+        (the base `rebuild` does not pass it to the constructor).
+        """
+        r = super().rebuild(
+            num_inference_steps=num_inference_steps,
+            first_inference_step=first_inference_step,
+        )
+        r.last_step_first_order = self.last_step_first_order
+        return r
+
     def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
         current_timestep = self.timesteps[step]
         previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py
index 37f9beb..acd6425 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py
@@ -77,6 +77,26 @@ class Scheduler(ABC):
     def inference_steps(self) -> list[int]:
         return self.all_steps[self.first_inference_step :]
 
+    def rebuild(self: T, num_inference_steps: int | None = None, first_inference_step: int | None = None) -> T:
+        """
+        Return a new instance of the same class, reusing this scheduler's `num_train_timesteps`,
+        diffusion rates, noise schedule, device and dtype.
+
+        `num_inference_steps` and `first_inference_step` override the current values; passing `None`
+        keeps the value currently set on this scheduler.
+        """
+        num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
+        first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
+        return self.__class__(
+            num_inference_steps=num_inference_steps,
+            num_train_timesteps=self.num_train_timesteps,
+            initial_diffusion_rate=self.initial_diffusion_rate,
+            final_diffusion_rate=self.final_diffusion_rate,
+            noise_schedule=self.noise_schedule,
+            first_inference_step=first_inference_step,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
     def scale_model_input(self, x: Tensor, step: int) -> Tensor:
         """
         For compatibility with schedulers that need to scale the input according to the current timestep.