add DINOv2 similarity to compare_images + relax some test constraints

This commit is contained in:
Laurent 2024-10-14 11:50:56 +00:00 committed by Laureηt
parent d498ecd369
commit ef0a7525a8
6 changed files with 75 additions and 34 deletions

View file

@ -832,7 +832,7 @@ def test_diffusion_std_random_init_bfloat16(
) )
predicted_image = sd15.lda.latents_to_image(x) predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16) ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16, min_psnr=30, min_ssim=0.97)
@no_grad() @no_grad()
@ -1166,7 +1166,7 @@ def test_diffusion_inpainting_float16(
predicted_image = sd15.lda.latents_to_image(x) predicted_image = sd15.lda.latents_to_image(x)
# PSNR and SSIM values are large because float16 is even worse than float32. # PSNR and SSIM values are large because float16 is even worse than float32.
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92) ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95, min_dinov2=0.96)
@no_grad() @no_grad()
@ -1245,7 +1245,7 @@ def test_diffusion_controlnet_tile_upscale(
predicted_image = sd15.lda.latents_to_image(x) predicted_image = sd15.lda.latents_to_image(x)
# Note: rather large tolerances are used on purpose here (loose comparison with diffusers' output) # Note: rather large tolerances are used on purpose here (loose comparison with diffusers' output)
ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75) ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75, min_dinov2=0.94)
@no_grad() @no_grad()
@ -1852,7 +1852,7 @@ def test_diffusion_ella_adapter(
condition_scale=12, condition_scale=12,
) )
predicted_image = sd15.lda.latents_to_image(x) predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=35, min_ssim=0.98) ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=31, min_ssim=0.98)
@no_grad() @no_grad()
@ -1937,7 +1937,7 @@ def test_diffusion_ip_adapter_multi(
) )
predicted_image = sd15.lda.decode_latents(x) predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi) ensure_similar_images(predicted_image, expected_image_ip_adapter_multi, min_psnr=43, min_ssim=0.98)
@no_grad() @no_grad()
@ -2130,7 +2130,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
sdxl.lda.to(dtype=torch.float32) sdxl.lda.to(dtype=torch.float32)
predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32)) predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman) ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman, min_psnr=43, min_ssim=0.98)
@no_grad() @no_grad()
@ -2608,11 +2608,11 @@ def test_style_aligned(
# tile all images horizontally # tile all images horizontally
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024)) merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
for i in range(len(predicted_images)): for i, image in enumerate(predicted_images):
merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore merged_image.paste(image, (1024 * i, 0))
# compare against reference image # compare against reference image
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99) ensure_similar_images(merged_image, expected_style_aligned, min_psnr=12, min_ssim=0.39, min_dinov2=0.95)
@no_grad() @no_grad()
@ -2624,7 +2624,7 @@ def test_multi_upscaler(
generator = torch.Generator(device=multi_upscaler.device) generator = torch.Generator(device=multi_upscaler.device)
generator.manual_seed(37) generator.manual_seed(37)
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator) predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99) ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=25, min_ssim=0.85, min_dinov2=0.96)
@no_grad() @no_grad()

View file

@ -110,7 +110,7 @@ def test_guide_adapting_sdxl_vanilla(
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad() @no_grad()
@ -152,7 +152,7 @@ def test_guide_adapting_sdxl_single_lora(
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image) ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
@no_grad() @no_grad()
@ -196,7 +196,7 @@ def test_guide_adapting_sdxl_multiple_loras(
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image) ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
@no_grad() @no_grad()
@ -256,7 +256,7 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
) )
predicted_image = sdxl.lda.decode_latents(x) predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image) ensure_similar_images(predicted_image, expected_image, min_psnr=29, min_ssim=0.98)
# We do not (yet) test the last example using T2i-Adapter with Zoe Depth. # We do not (yet) test the last example using T2i-Adapter with Zoe Depth.

View file

@ -93,7 +93,7 @@ def test_lightning_base_4step(
) )
predicted_image = sdxl.lda.latents_to_image(x) predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image) ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)
@no_grad() @no_grad()
@ -144,7 +144,7 @@ def test_lightning_base_1step(
) )
predicted_image = sdxl.lda.latents_to_image(x) predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image) ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)
@no_grad() @no_grad()
@ -198,4 +198,4 @@ def test_lightning_lora_4step(
) )
predicted_image = sdxl.lda.latents_to_image(x) predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image) ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)

View file

@ -238,15 +238,15 @@ def test_predictor(
assert torch.allclose( assert torch.allclose(
reference_low_res_mask_hq, reference_low_res_mask_hq,
refiners_low_res_mask_hq, refiners_low_res_mask_hq,
atol=4e-3, atol=1e-2,
)
assert ( # absolute diff in number of pixels
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 10
) )
assert (
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 2
) # The diff on the logits above leads to an absolute diff of 2 pixel on the high res masks
assert torch.allclose( assert torch.allclose(
iou_predictions_np, iou_predictions_np,
torch.max(iou_predictions), torch.max(iou_predictions),
atol=1e-5, atol=1e-4,
) )

View file

@ -317,8 +317,9 @@ def test_predictor(
for i in range(3): for i in range(3):
mask_prediction = masks[i].cpu() mask_prediction = masks[i].cpu()
facebook_mask = torch.as_tensor(facebook_masks[i]) facebook_mask = torch.as_tensor(facebook_masks[i])
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05) iou = intersection_over_union(mask_prediction, facebook_mask)
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05) assert isclose(iou, 1.0, rel_tol=5e-04), f"iou: {iou}"
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-04)
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None: def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:

View file

@ -1,25 +1,65 @@
from functools import cache
from pathlib import Path from pathlib import Path
from textwrap import dedent
import numpy as np
import piq # type: ignore import piq # type: ignore
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer # type: ignore from transformers import T5EncoderModel, T5Tokenizer # type: ignore
from refiners.conversion.models import dinov2
from refiners.fluxion.utils import image_to_tensor
from refiners.foundationals.dinov2 import DINOv2_small
def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[int, float]:
@cache
def get_small_dinov2_model() -> DINOv2_small:
    """Load the small DINOv2 model used for perceptual image similarity.

    Weights are loaded from the locally converted safetensors file when it
    exists, otherwise from the HF cache path. The result is memoized with
    ``functools.cache`` so the weights are read from disk only once per
    test session.
    """
    model = DINOv2_small()
    model.load_from_safetensors(
        dinov2.small.converted.local_path
        if dinov2.small.converted.local_path.exists()
        else dinov2.small.converted.hf_cache_path
    )
    return model
def compare_images(
    img_1: Image.Image,
    img_2: Image.Image,
) -> tuple[float, float, float]:
    """Return (PSNR, SSIM, DINOv2 cosine similarity) between two images.

    PSNR and SSIM are pixel-level metrics computed with ``piq``; the third
    value is the cosine similarity of DINOv2 embeddings, a perceptual
    metric that is more tolerant of small pixel-level differences.
    """
    x1 = image_to_tensor(img_1)
    x2 = image_to_tensor(img_2)
    psnr = piq.psnr(x1, x2)  # type: ignore
    ssim = piq.ssim(x1, x2)  # type: ignore
    dinov2_model = get_small_dinov2_model()
    # Index 0 of the token axis — assumed to be the class-token embedding
    # of DINOv2_small's output; TODO(review): confirm output layout.
    dinov2 = torch.nn.functional.cosine_similarity(
        dinov2_model(x1)[:, 0],
        dinov2_model(x2)[:, 0],
    )
    return psnr.item(), ssim.item(), dinov2.item()  # type: ignore
def ensure_similar_images(
    img_1: Image.Image,
    img_2: Image.Image,
    min_psnr: int = 45,
    min_ssim: float = 0.99,
    min_dinov2: float = 0.99,
) -> None:
    """Assert that two images are similar under PSNR, SSIM and DINOv2.

    Args:
        img_1: first image to compare.
        img_2: second image to compare.
        min_psnr: minimum acceptable peak signal-to-noise ratio.
        min_ssim: minimum acceptable structural similarity.
        min_dinov2: minimum acceptable DINOv2 embedding cosine similarity.

    Raises:
        AssertionError: if any metric falls below its threshold; the
            message reports all three measured values and their minima.
    """
    psnr, ssim, dinov2 = compare_images(img_1, img_2)
    if (psnr < min_psnr) or (ssim < min_ssim) or (dinov2 < min_dinov2):
        raise AssertionError(
            dedent(f"""
            Images are not similar enough!
            - PSNR: {psnr:08.05f} (required at least {min_psnr:08.05f})
            - SSIM: {ssim:08.06f} (required at least {min_ssim:08.06f})
            - DINO: {dinov2:08.06f} (required at least {min_dinov2:08.06f})
            """).strip()
        )
class T5TextEmbedder(nn.Module): class T5TextEmbedder(nn.Module):