add rebuild() to Scheduler interface

for use in `set_inference_steps()`
This commit is contained in:
Pierre Chapuis 2024-01-23 09:37:57 +01:00
parent a5c665462a
commit fb2f0e28d4
3 changed files with 27 additions and 9 deletions

View file

@ -33,15 +33,7 @@ class LatentDiffusionModel(fl.Module, ABC):
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype) self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None: def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
initial_diffusion_rate = self.scheduler.initial_diffusion_rate self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
final_diffusion_rate = self.scheduler.final_diffusion_rate
device, dtype = self.scheduler.device, self.scheduler.dtype
self.scheduler = self.scheduler.__class__(
num_inference_steps=num_steps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
first_inference_step=first_step,
).to(device=device, dtype=dtype)
def init_latents( def init_latents(
self, self,

View file

@ -51,6 +51,18 @@ class DPMSolver(Scheduler):
device=self.device, device=self.device,
).flip(0) ).flip(0)
def rebuild(
self: "DPMSolver",
num_inference_steps: int | None,
first_inference_step: int | None = None,
) -> "DPMSolver":
r = super().rebuild(
num_inference_steps=num_inference_steps,
first_inference_step=first_inference_step,
)
r.last_step_first_order = self.last_step_first_order
return r
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
current_timestep = self.timesteps[step] current_timestep = self.timesteps[step]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0]) previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])

View file

@ -77,6 +77,20 @@ class Scheduler(ABC):
def inference_steps(self) -> list[int]: def inference_steps(self) -> list[int]:
return self.all_steps[self.first_inference_step :] return self.all_steps[self.first_inference_step :]
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
return self.__class__(
num_inference_steps=num_inference_steps,
num_train_timesteps=self.num_train_timesteps,
initial_diffusion_rate=self.initial_diffusion_rate,
final_diffusion_rate=self.final_diffusion_rate,
noise_schedule=self.noise_schedule,
first_inference_step=first_inference_step,
device=self.device,
dtype=self.dtype,
)
def scale_model_input(self, x: Tensor, step: int) -> Tensor: def scale_model_input(self, x: Tensor, step: int) -> Tensor:
""" """
For compatibility with schedulers that need to scale the input according to the current timestep. For compatibility with schedulers that need to scale the input according to the current timestep.