mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add rebuild()
to Scheduler interface
for use in `set_inference_steps()`
This commit is contained in:
parent
a5c665462a
commit
fb2f0e28d4
|
@ -33,15 +33,7 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
||||
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
|
||||
final_diffusion_rate = self.scheduler.final_diffusion_rate
|
||||
device, dtype = self.scheduler.device, self.scheduler.dtype
|
||||
self.scheduler = self.scheduler.__class__(
|
||||
num_inference_steps=num_steps,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
first_inference_step=first_step,
|
||||
).to(device=device, dtype=dtype)
|
||||
self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
||||
|
||||
def init_latents(
|
||||
self,
|
||||
|
|
|
@ -51,6 +51,18 @@ class DPMSolver(Scheduler):
|
|||
device=self.device,
|
||||
).flip(0)
|
||||
|
||||
def rebuild(
|
||||
self: "DPMSolver",
|
||||
num_inference_steps: int | None,
|
||||
first_inference_step: int | None = None,
|
||||
) -> "DPMSolver":
|
||||
r = super().rebuild(
|
||||
num_inference_steps=num_inference_steps,
|
||||
first_inference_step=first_inference_step,
|
||||
)
|
||||
r.last_step_first_order = self.last_step_first_order
|
||||
return r
|
||||
|
||||
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
current_timestep = self.timesteps[step]
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||
|
|
|
@ -77,6 +77,20 @@ class Scheduler(ABC):
|
|||
def inference_steps(self) -> list[int]:
|
||||
return self.all_steps[self.first_inference_step :]
|
||||
|
||||
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
|
||||
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
|
||||
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
|
||||
return self.__class__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
initial_diffusion_rate=self.initial_diffusion_rate,
|
||||
final_diffusion_rate=self.final_diffusion_rate,
|
||||
noise_schedule=self.noise_schedule,
|
||||
first_inference_step=first_inference_step,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
"""
|
||||
For compatibility with schedulers that need to scale the input according to the current timestep.
|
||||
|
|
Loading…
Reference in a new issue