mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
add a test for SDXL + EulerScheduler (deterministic)
This commit is contained in:
parent
5ac5373310
commit
8a2b955bd0
|
@ -135,12 +135,17 @@ def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert(mode="RGB")
|
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
|
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert(mode="RGB")
|
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||||
|
@ -627,6 +632,18 @@ def sdxl_ddim_lda_fp16_fix(
|
||||||
return sdxl
|
return sdxl
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_XL:
|
||||||
|
return StableDiffusion_XL(
|
||||||
|
unet=sdxl_ddim.unet,
|
||||||
|
lda=sdxl_ddim.lda,
|
||||||
|
clip_text_encoder=sdxl_ddim.clip_text_encoder,
|
||||||
|
solver=Euler(num_inference_steps=30),
|
||||||
|
device=sdxl_ddim.device,
|
||||||
|
dtype=sdxl_ddim.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_std_random_init(
|
def test_diffusion_std_random_init(
|
||||||
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||||
|
@ -1684,6 +1701,44 @@ def test_diffusion_sdxl_sliced_attention(
|
||||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_diffusion_sdxl_euler_deterministic(
|
||||||
|
sdxl_euler_deterministic: StableDiffusion_XL, expected_sdxl_euler_random_init: Image.Image
|
||||||
|
) -> None:
|
||||||
|
sdxl = sdxl_euler_deterministic
|
||||||
|
assert isinstance(sdxl.solver, Euler)
|
||||||
|
|
||||||
|
expected_image = expected_sdxl_euler_random_init
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||||
|
text=prompt, negative_text=negative_prompt
|
||||||
|
)
|
||||||
|
time_ids = sdxl.default_time_ids
|
||||||
|
sdxl.set_inference_steps(30)
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype)
|
||||||
|
|
||||||
|
# init latents must be scaled for Euler
|
||||||
|
# TODO make init_latents work
|
||||||
|
x = x * sdxl.solver.init_noise_sigma
|
||||||
|
|
||||||
|
for step in sdxl.steps:
|
||||||
|
x = sdxl(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
pooled_text_embedding=pooled_text_embedding,
|
||||||
|
time_ids=time_ids,
|
||||||
|
condition_scale=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
predicted_image = sdxl.lda.decode_latents(x)
|
||||||
|
ensure_similar_images(predicted_image, expected_image)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
|
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
|
||||||
manual_seed(seed=2)
|
manual_seed(seed=2)
|
||||||
|
|
|
@ -45,6 +45,7 @@ Special cases:
|
||||||
- `expected_t2i_adapter_xl_canny.png`
|
- `expected_t2i_adapter_xl_canny.png`
|
||||||
- `expected_image_sdxl_ip_adapter_plus_woman.png`
|
- `expected_image_sdxl_ip_adapter_plus_woman.png`
|
||||||
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
|
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
|
||||||
|
- `expected_cutecat_sdxl_euler_random_init.png`
|
||||||
- `expected_restart.png`
|
- `expected_restart.png`
|
||||||
- `expected_freeu.png`
|
- `expected_freeu.png`
|
||||||
- `expected_dropy_slime_9752.png`
|
- `expected_dropy_slime_9752.png`
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 1.6 MiB |
Loading…
Reference in a new issue