mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
improve DPM solver test
This commit is contained in:
parent
999e429697
commit
ce3035923b
|
@ -18,14 +18,21 @@ def test_ddpm_diffusers():
|
||||||
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
|
assert equal(diffusers_scheduler.timesteps, refiners_scheduler.timesteps)
|
||||||
|
|
||||||
|
|
||||||
def test_dpm_solver_diffusers():
|
@pytest.mark.parametrize("n_steps, last_step_first_order", [(5, False), (5, True), (30, False), (30, True)])
|
||||||
|
def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
||||||
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
from diffusers import DPMSolverMultistepScheduler as DiffuserScheduler # type: ignore
|
||||||
|
|
||||||
manual_seed(0)
|
manual_seed(0)
|
||||||
|
|
||||||
diffusers_scheduler = DiffuserScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
|
diffusers_scheduler = DiffuserScheduler(
|
||||||
diffusers_scheduler.set_timesteps(30)
|
beta_schedule="scaled_linear",
|
||||||
refiners_scheduler = DPMSolver(num_inference_steps=30)
|
beta_start=0.00085,
|
||||||
|
beta_end=0.012,
|
||||||
|
lower_order_final=False,
|
||||||
|
euler_at_final=last_step_first_order,
|
||||||
|
)
|
||||||
|
diffusers_scheduler.set_timesteps(n_steps)
|
||||||
|
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
|
||||||
|
|
||||||
sample = randn(1, 3, 32, 32)
|
sample = randn(1, 3, 32, 32)
|
||||||
noise = randn(1, 3, 32, 32)
|
noise = randn(1, 3, 32, 32)
|
||||||
|
|
Loading…
Reference in a new issue