mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add rebuild()
to Scheduler interface
for use in `set_inference_steps()`
This commit is contained in:
parent
a5c665462a
commit
fb2f0e28d4
|
@ -33,15 +33,7 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
|
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
||||||
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
|
self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
||||||
final_diffusion_rate = self.scheduler.final_diffusion_rate
|
|
||||||
device, dtype = self.scheduler.device, self.scheduler.dtype
|
|
||||||
self.scheduler = self.scheduler.__class__(
|
|
||||||
num_inference_steps=num_steps,
|
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
|
||||||
first_inference_step=first_step,
|
|
||||||
).to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def init_latents(
|
def init_latents(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -51,6 +51,18 @@ class DPMSolver(Scheduler):
|
||||||
device=self.device,
|
device=self.device,
|
||||||
).flip(0)
|
).flip(0)
|
||||||
|
|
||||||
|
def rebuild(
|
||||||
|
self: "DPMSolver",
|
||||||
|
num_inference_steps: int | None,
|
||||||
|
first_inference_step: int | None = None,
|
||||||
|
) -> "DPMSolver":
|
||||||
|
r = super().rebuild(
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
first_inference_step=first_inference_step,
|
||||||
|
)
|
||||||
|
r.last_step_first_order = self.last_step_first_order
|
||||||
|
return r
|
||||||
|
|
||||||
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
current_timestep = self.timesteps[step]
|
current_timestep = self.timesteps[step]
|
||||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||||
|
|
|
@ -77,6 +77,20 @@ class Scheduler(ABC):
|
||||||
def inference_steps(self) -> list[int]:
|
def inference_steps(self) -> list[int]:
|
||||||
return self.all_steps[self.first_inference_step :]
|
return self.all_steps[self.first_inference_step :]
|
||||||
|
|
||||||
|
def rebuild(self: T, num_inference_steps: int | None, first_inference_step: int | None = None) -> T:
|
||||||
|
num_inference_steps = self.num_inference_steps if num_inference_steps is None else num_inference_steps
|
||||||
|
first_inference_step = self.first_inference_step if first_inference_step is None else first_inference_step
|
||||||
|
return self.__class__(
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
num_train_timesteps=self.num_train_timesteps,
|
||||||
|
initial_diffusion_rate=self.initial_diffusion_rate,
|
||||||
|
final_diffusion_rate=self.final_diffusion_rate,
|
||||||
|
noise_schedule=self.noise_schedule,
|
||||||
|
first_inference_step=first_inference_step,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
For compatibility with schedulers that need to scale the input according to the current timestep.
|
For compatibility with schedulers that need to scale the input according to the current timestep.
|
||||||
|
|
Loading…
Reference in a new issue