mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
add a test for SDXL + EulerScheduler (deterministic)
This commit is contained in:
parent
5ac5373310
commit
8a2b955bd0
|
@ -135,12 +135,17 @@ def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
|
|||
|
||||
@pytest.fixture
|
||||
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
|
||||
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert(mode="RGB")
|
||||
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||
|
@ -627,6 +632,18 @@ def sdxl_ddim_lda_fp16_fix(
|
|||
return sdxl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_XL:
|
||||
return StableDiffusion_XL(
|
||||
unet=sdxl_ddim.unet,
|
||||
lda=sdxl_ddim.lda,
|
||||
clip_text_encoder=sdxl_ddim.clip_text_encoder,
|
||||
solver=Euler(num_inference_steps=30),
|
||||
device=sdxl_ddim.device,
|
||||
dtype=sdxl_ddim.dtype,
|
||||
)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_std_random_init(
|
||||
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||
|
@ -1684,6 +1701,44 @@ def test_diffusion_sdxl_sliced_attention(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_sdxl_euler_deterministic(
|
||||
sdxl_euler_deterministic: StableDiffusion_XL, expected_sdxl_euler_random_init: Image.Image
|
||||
) -> None:
|
||||
sdxl = sdxl_euler_deterministic
|
||||
assert isinstance(sdxl.solver, Euler)
|
||||
|
||||
expected_image = expected_sdxl_euler_random_init
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt, negative_text=negative_prompt
|
||||
)
|
||||
time_ids = sdxl.default_time_ids
|
||||
sdxl.set_inference_steps(30)
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
|
||||
|
||||
# init latents must be scaled for Euler
|
||||
# TODO make init_latents work
|
||||
x = x * sdxl.solver.init_noise_sigma
|
||||
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
pooled_text_embedding=pooled_text_embedding,
|
||||
time_ids=time_ids,
|
||||
condition_scale=5,
|
||||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
|
||||
manual_seed(seed=2)
|
||||
|
|
|
@ -45,6 +45,7 @@ Special cases:
|
|||
- `expected_t2i_adapter_xl_canny.png`
|
||||
- `expected_image_sdxl_ip_adapter_plus_woman.png`
|
||||
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
|
||||
- `expected_cutecat_sdxl_euler_random_init.png`
|
||||
- `expected_restart.png`
|
||||
- `expected_freeu.png`
|
||||
- `expected_dropy_slime_9752.png`
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 1.6 MiB |
Loading…
Reference in a new issue