diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 80bd583..1b1279e 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -25,11 +25,19 @@ from refiners.foundationals.latent_diffusion import ( StableDiffusion_1_Inpainting, ) from refiners.foundationals.latent_diffusion.lora import SDLoraManager -from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget +from refiners.foundationals.latent_diffusion.multi_diffusion import Size, Tile from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter from refiners.foundationals.latent_diffusion.restart import Restart from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams -from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion +from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver +from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import ( + SD1DiffusionTarget, + SD1MultiDiffusion, +) +from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import ( + MultiUpscaler, + UpscalerCheckpoints, +) from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter @@ -406,6 +414,11 @@ def expected_multi_diffusion(ref_path: Path) -> Image.Image: return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB") +@pytest.fixture +def expected_multi_diffusion_dpm(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB") + + @pytest.fixture def expected_restart(ref_path: Path) -> Image.Image: return _img_open(ref_path / "expected_restart.png").convert(mode="RGB") @@ -756,6 +769,41 @@ def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_X ) +@pytest.fixture(scope="module") +def multi_upscaler( + test_weights_path: Path, + unet_weights_std: Path, + text_encoder_weights: Path, + lda_ft_mse_weights: Path, + test_device: torch.device, +) -> MultiUpscaler: + controlnet_tile_weights = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors" + if not controlnet_tile_weights.is_file(): + warn(message=f"could not find weights at {controlnet_tile_weights}, skipping") + pytest.skip(allow_module_level=True) + + return MultiUpscaler( + checkpoints=UpscalerCheckpoints( + unet=unet_weights_std, + clip_text_encoder=text_encoder_weights, + lda=lda_ft_mse_weights, + controlnet_tile=controlnet_tile_weights, + ), + device=test_device, + dtype=torch.float32, + ) + + +@pytest.fixture(scope="module") +def clarity_example(ref_path: Path) -> Image.Image: + return Image.open(ref_path / "clarity_input_example.png") + + +@pytest.fixture(scope="module") +def expected_multi_upscaler(ref_path: Path) -> Image.Image: + return Image.open(ref_path / "expected_multi_upscaler.png") + + @no_grad() def test_diffusion_std_random_init( sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device @@ -2132,15 +2180,15 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: sd = sd15_ddim multi_diffusion = SD1MultiDiffusion(sd) clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain") - target_1 = DiffusionTarget( - size=(64, 64), - offset=(0, 0), + # DDIM doesn't have an internal state, so we can share the same solver for all targets + target_1 = SD1DiffusionTarget( + tile=Tile(top=0, left=0, bottom=64, right=64), + solver=sd.solver, clip_text_embedding=clip_text_embedding, - start_step=0, ) - target_2 = DiffusionTarget( - size=(64, 64), - offset=(0, 16), + target_2 = SD1DiffusionTarget( + solver=sd.solver, + tile=Tile(top=0, left=16, bottom=64, right=80), clip_text_embedding=clip_text_embedding, condition_scale=3, start_step=0, @@ -2158,6 +2206,35 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98) +@no_grad() +def test_multi_diffusion_dpm(sd15_std: StableDiffusion_1, expected_multi_diffusion_dpm: Image.Image) -> None: + manual_seed(seed=2) + sd = sd15_std + multi_diffusion = SD1MultiDiffusion(sd) + clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain") + tiles = SD1MultiDiffusion.generate_latent_tiles(size=Size(112, 196), tile_size=Size(96, 64), min_overlap=12) + targets = [ + SD1DiffusionTarget( + tile=tile, + solver=DPMSolver(num_inference_steps=sd.solver.num_inference_steps, device=sd.device), + clip_text_embedding=clip_text_embedding, + ) + for tile in tiles + ] + + noise = torch.randn(1, 4, 112, 196, device=sd.device, dtype=sd.dtype) + x = noise + for step in sd.steps: + x = multi_diffusion( + x, + noise=noise, + step=step, + targets=targets, + ) + result = sd.lda.latents_to_image(x=x) + ensure_similar_images(img_1=result, img_2=expected_multi_diffusion_dpm, min_psnr=35, min_ssim=0.98) + + @no_grad() def test_t2i_adapter_depth( sd15_std: StableDiffusion_1, @@ -2427,3 +2504,13 @@ def test_style_aligned( # compare against reference image ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99) + + +@no_grad() +def test_multi_upscaler( + multi_upscaler: MultiUpscaler, + clarity_example: Image.Image, + expected_multi_upscaler: Image.Image, +) -> None: + predicted_image = multi_upscaler.upscale(clarity_example) + ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99) diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index 73cea4c..9e7075a 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -58,6 +58,8 @@ Special cases: - `expected_controllora_disabled.png` - `expected_style_aligned.png` - `expected_controlnet_canny_scale_decay.png` + - `expected_multi_diffusion_dpm.png` + - `expected_multi_upscaler.png` ## Other images @@ -92,6 +94,8 @@ Special cases: - `low_res_dog.png` and `expected_controlnet_tile.png` are taken from Diffusers [documentation](https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/tree/main/images), respectively named `original.png` and `output.png`. +- `clarity_input_example.png` is taken from the [Replicate demo](https://replicate.com/philz1337x/clarity-upscaler/examples) of the Clarity upscaler. + ## VAE without randomness ```diff diff --git a/tests/e2e/test_diffusion_ref/clarity_input_example.png b/tests/e2e/test_diffusion_ref/clarity_input_example.png new file mode 100644 index 0000000..8476fd5 Binary files /dev/null and b/tests/e2e/test_diffusion_ref/clarity_input_example.png differ diff --git a/tests/e2e/test_diffusion_ref/expected_multi_diffusion_dpm.png b/tests/e2e/test_diffusion_ref/expected_multi_diffusion_dpm.png new file mode 100644 index 0000000..cfea1a3 Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_multi_diffusion_dpm.png differ diff --git a/tests/e2e/test_diffusion_ref/expected_multi_upscaler.png b/tests/e2e/test_diffusion_ref/expected_multi_upscaler.png new file mode 100644 index 0000000..820dad1 Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_multi_upscaler.png differ diff --git a/tests/foundationals/latent_diffusion/test_multi_diffusion.py b/tests/foundationals/latent_diffusion/test_multi_diffusion.py new file mode 100644 index 0000000..8ca13f0 --- /dev/null +++ b/tests/foundationals/latent_diffusion/test_multi_diffusion.py @@ -0,0 +1,34 @@ +import pytest + +from refiners.foundationals.latent_diffusion.multi_diffusion import MultiDiffusion, Size + + +def test_generate_latent_tiles() -> None: + size = Size(height=128, width=128) + tile_size = Size(height=32, width=32) + tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size) + assert len(tiles) == 25 + + tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size, min_overlap=0) + assert len(tiles) == 16 + + size = Size(height=100, width=200) + tile_size = Size(height=32, width=32) + tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size, min_overlap=2) + assert len(tiles) == 28 + + +def test_generate_latent_tiles_small_size() -> None: + # Test when the size is smaller than the tile size + size = Size(height=32, width=32) + tile_size = Size(height=64, width=64) + tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size) + assert len(tiles) == 1 + assert Size(tiles[0].bottom - tiles[0].top, tiles[0].right - tiles[0].left) == size + + +def test_overlap_larger_tile_size() -> None: + with pytest.raises(AssertionError): + size = Size(height=128, width=128) + tile_size = Size(height=32, width=32) + MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size, min_overlap=32)