mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
improve DPM solver test
This commit is contained in:
parent
999e429697
commit
ce3035923b
|
@ -18,14 +18,21 @@ def test_ddpm_diffusers():
|
|||
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
|
||||
|
||||
|
||||
def test_dpm_solver_diffusers():
|
||||
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
|
||||
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
||||
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
||||
|
||||
manual_seed(0)
|
||||
|
||||
diffusers_scheduler = DiffuserScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
|
||||
diffusers_scheduler.set_timesteps(30)
|
||||
refiners_scheduler = DPMSolver(num_inference_steps=30)
|
||||
diffusers_scheduler = DiffuserScheduler(
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
lower_order_final=False,
|
||||
euler_at_final=last_step_first_order,
|
||||
)
|
||||
diffusers_scheduler.set_timesteps(n_steps)
|
||||
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
|
||||
|
||||
sample = randn(1, 3, 32, 32)
|
||||
noise = randn(1, 3, 32, 32)
|
||||
|
|
Loading…
Reference in a new issue