diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 8b2e211..40e3dbe 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -832,7 +832,7 @@ def test_diffusion_std_random_init_bfloat16( ) predicted_image = sd15.lda.latents_to_image(x) - ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16) + ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16, min_psnr=30, min_ssim=0.97) @no_grad() @@ -1166,7 +1166,7 @@ def test_diffusion_inpainting_float16( predicted_image = sd15.lda.latents_to_image(x) # PSNR and SSIM values are large because float16 is even worse than float32. - ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92) + ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95, min_dinov2=0.96) @no_grad() @@ -1245,7 +1245,7 @@ def test_diffusion_controlnet_tile_upscale( predicted_image = sd15.lda.latents_to_image(x) # Note: rather large tolerances are used on purpose here (loose comparison with diffusers' output) - ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75) + ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75, min_dinov2=0.94) @no_grad() @@ -1852,7 +1852,7 @@ def test_diffusion_ella_adapter( condition_scale=12, ) predicted_image = sd15.lda.latents_to_image(x) - ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=35, min_ssim=0.98) + ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=31, min_ssim=0.98) @no_grad() @@ -1937,7 +1937,7 @@ def test_diffusion_ip_adapter_multi( ) predicted_image = sd15.lda.decode_latents(x) - ensure_similar_images(predicted_image, expected_image_ip_adapter_multi) + ensure_similar_images(predicted_image, expected_image_ip_adapter_multi, min_psnr=43, min_ssim=0.98) @no_grad() @@ -2130,7 +2130,7 @@ def test_diffusion_sdxl_ip_adapter_plus( sdxl.lda.to(dtype=torch.float32) predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32)) - ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman) + ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman, min_psnr=43, min_ssim=0.98) @no_grad() @@ -2608,11 +2608,11 @@ def test_style_aligned( # tile all images horizontally merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024)) - for i in range(len(predicted_images)): - merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore + for i, image in enumerate(predicted_images): + merged_image.paste(image, (1024 * i, 0)) # compare against reference image - ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99) + ensure_similar_images(merged_image, expected_style_aligned, min_psnr=12, min_ssim=0.39, min_dinov2=0.95) @no_grad() @@ -2624,7 +2624,7 @@ def test_multi_upscaler( generator = torch.Generator(device=multi_upscaler.device) generator.manual_seed(37) predicted_image = multi_upscaler.upscale(clarity_example, generator=generator) - ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99) + ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=25, min_ssim=0.85, min_dinov2=0.96) @no_grad() diff --git a/tests/e2e/test_doc_examples.py b/tests/e2e/test_doc_examples.py index 0289b9f..e0f5cf8 100644 --- a/tests/e2e/test_doc_examples.py +++ b/tests/e2e/test_doc_examples.py @@ -110,7 +110,7 @@ def test_guide_adapting_sdxl_vanilla( ) predicted_image = sdxl.lda.decode_latents(x) - ensure_similar_images(predicted_image, expected_image) + ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @no_grad() @@ -152,7 +152,7 @@ def test_guide_adapting_sdxl_single_lora( ) predicted_image = sdxl.lda.decode_latents(x) - ensure_similar_images(predicted_image, expected_image) + ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98) @no_grad() @@ -196,7 +196,7 @@ def test_guide_adapting_sdxl_multiple_loras( ) predicted_image = sdxl.lda.decode_latents(x) - ensure_similar_images(predicted_image, expected_image) + ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98) @no_grad() @@ -256,7 +256,7 @@ def test_guide_adapting_sdxl_loras_ip_adapter( ) predicted_image = sdxl.lda.decode_latents(x) - ensure_similar_images(predicted_image, expected_image) + ensure_similar_images(predicted_image, expected_image, min_psnr=29, min_ssim=0.98) # We do not (yet) test the last example using T2i-Adapter with Zoe Depth. diff --git a/tests/e2e/test_lightning.py b/tests/e2e/test_lightning.py index 83d701e..0acf358 100644 --- a/tests/e2e/test_lightning.py +++ b/tests/e2e/test_lightning.py @@ -93,7 +93,7 @@ def test_lightning_base_4step( ) predicted_image = sdxl.lda.latents_to_image(x) - ensure_similar_images(predicted_image, expected_image) + ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98) @no_grad() @@ -144,7 +144,7 @@ def test_lightning_base_1step( ) predicted_image = sdxl.lda.latents_to_image(x) - ensure_similar_images(predicted_image, expected_image) + ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98) @no_grad() @@ -198,4 +198,4 @@ def test_lightning_lora_4step( ) predicted_image = sdxl.lda.latents_to_image(x) - ensure_similar_images(predicted_image, expected_image) + ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98) diff --git a/tests/foundationals/segment_anything/test_hq_sam.py b/tests/foundationals/segment_anything/test_hq_sam.py index 439bd26..9c5bb3c 100644 --- a/tests/foundationals/segment_anything/test_hq_sam.py +++ b/tests/foundationals/segment_anything/test_hq_sam.py @@ -238,15 +238,15 @@ def test_predictor( assert torch.allclose( reference_low_res_mask_hq, refiners_low_res_mask_hq, - atol=4e-3, + atol=1e-2, + ) + assert ( # absolute diff in number of pixels + torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 10 ) - assert ( - torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 2 - ) # The diff on the logits above leads to an absolute diff of 2 pixel on the high res masks assert torch.allclose( iou_predictions_np, torch.max(iou_predictions), - atol=1e-5, + atol=1e-4, ) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index 95eed70..96ab04c 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -317,8 +317,9 @@ def test_predictor( for i in range(3): mask_prediction = masks[i].cpu() facebook_mask = torch.as_tensor(facebook_masks[i]) - assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05) - assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05) + iou = intersection_over_union(mask_prediction, facebook_mask) + assert isclose(iou, 1.0, rel_tol=5e-04), f"iou: {iou}" + assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-04) def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None: diff --git a/tests/utils.py b/tests/utils.py index 963b44c..f1e99af 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,25 +1,65 @@ +from functools import cache from pathlib import Path +from textwrap import dedent -import numpy as np import piq # type: ignore import torch import torch.nn as nn from PIL import Image from transformers import T5EncoderModel, T5Tokenizer # type: ignore +from refiners.conversion.models import dinov2 +from refiners.fluxion.utils import image_to_tensor +from refiners.foundationals.dinov2 import DINOv2_small -def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[int, float]: - x1, x2 = ( - torch.tensor(np.array(x).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0 for x in (img_1, img_2) + +@cache +def get_small_dinov2_model() -> DINOv2_small: + model = DINOv2_small() + model.load_from_safetensors( + dinov2.small.converted.local_path + if dinov2.small.converted.local_path.exists() + else dinov2.small.converted.hf_cache_path ) - return (piq.psnr(x1, x2), piq.ssim(x1, x2).item()) # type: ignore + return model -def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int = 45, min_ssim: float = 0.99): - psnr, ssim = compare_images(img_1, img_2) - assert (psnr >= min_psnr) and ( - ssim >= min_ssim - ), f"PSNR {psnr} / SSIM {ssim}, expected at least {min_psnr} / {min_ssim}" +def compare_images( + img_1: Image.Image, + img_2: Image.Image, +) -> tuple[float, float, float]: + x1 = image_to_tensor(img_1) + x2 = image_to_tensor(img_2) + + psnr = piq.psnr(x1, x2) # type: ignore + ssim = piq.ssim(x1, x2) # type: ignore + + dinov2_model = get_small_dinov2_model() + dinov2 = torch.nn.functional.cosine_similarity( + dinov2_model(x1)[:, 0], + dinov2_model(x2)[:, 0], + ) + + return psnr.item(), ssim.item(), dinov2.item() # type: ignore + + +def ensure_similar_images( + img_1: Image.Image, + img_2: Image.Image, + min_psnr: int = 45, + min_ssim: float = 0.99, + min_dinov2: float = 0.99, +) -> None: + psnr, ssim, dinov2 = compare_images(img_1, img_2) + if (psnr < min_psnr) or (ssim < min_ssim) or (dinov2 < min_dinov2): + raise AssertionError( + dedent(f""" + Images are not similar enough! + - PSNR: {psnr:08.05f} (required at least {min_psnr:08.05f}) + - SSIM: {ssim:08.06f} (required at least {min_ssim:08.06f}) + - DINO: {dinov2:08.06f} (required at least {min_dinov2:08.06f}) + """).strip() + ) class T5TextEmbedder(nn.Module):