add DINOv2 similarity to compare_images + relax some test constraints

This commit is contained in:
Laurent 2024-10-14 11:50:56 +00:00 committed by Laureηt
parent d498ecd369
commit ef0a7525a8
6 changed files with 75 additions and 34 deletions

View file

@ -832,7 +832,7 @@ def test_diffusion_std_random_init_bfloat16(
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16)
ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16, min_psnr=30, min_ssim=0.97)
@no_grad()
@ -1166,7 +1166,7 @@ def test_diffusion_inpainting_float16(
predicted_image = sd15.lda.latents_to_image(x)
# PSNR and SSIM values are large because float16 is even worse than float32.
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95, min_dinov2=0.96)
@no_grad()
@ -1245,7 +1245,7 @@ def test_diffusion_controlnet_tile_upscale(
predicted_image = sd15.lda.latents_to_image(x)
# Note: rather large tolerances are used on purpose here (loose comparison with diffusers' output)
ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75)
ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75, min_dinov2=0.94)
@no_grad()
@ -1852,7 +1852,7 @@ def test_diffusion_ella_adapter(
condition_scale=12,
)
predicted_image = sd15.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=35, min_ssim=0.98)
ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=31, min_ssim=0.98)
@no_grad()
@ -1937,7 +1937,7 @@ def test_diffusion_ip_adapter_multi(
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi)
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi, min_psnr=43, min_ssim=0.98)
@no_grad()
@ -2130,7 +2130,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
sdxl.lda.to(dtype=torch.float32)
predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman, min_psnr=43, min_ssim=0.98)
@no_grad()
@ -2608,11 +2608,11 @@ def test_style_aligned(
# tile all images horizontally
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
for i in range(len(predicted_images)):
merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore
for i, image in enumerate(predicted_images):
merged_image.paste(image, (1024 * i, 0))
# compare against reference image
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=12, min_ssim=0.39, min_dinov2=0.95)
@no_grad()
@ -2624,7 +2624,7 @@ def test_multi_upscaler(
generator = torch.Generator(device=multi_upscaler.device)
generator.manual_seed(37)
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=25, min_ssim=0.85, min_dinov2=0.96)
@no_grad()

View file

@ -110,7 +110,7 @@ def test_guide_adapting_sdxl_vanilla(
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@no_grad()
@ -152,7 +152,7 @@ def test_guide_adapting_sdxl_single_lora(
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
@no_grad()
@ -196,7 +196,7 @@ def test_guide_adapting_sdxl_multiple_loras(
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
@no_grad()
@ -256,7 +256,7 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
)
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)
ensure_similar_images(predicted_image, expected_image, min_psnr=29, min_ssim=0.98)
# We do not (yet) test the last example using T2i-Adapter with Zoe Depth.

View file

@ -93,7 +93,7 @@ def test_lightning_base_4step(
)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image)
ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)
@no_grad()
@ -144,7 +144,7 @@ def test_lightning_base_1step(
)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image)
ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)
@no_grad()
@ -198,4 +198,4 @@ def test_lightning_lora_4step(
)
predicted_image = sdxl.lda.latents_to_image(x)
ensure_similar_images(predicted_image, expected_image)
ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)

View file

@ -238,15 +238,15 @@ def test_predictor(
assert torch.allclose(
reference_low_res_mask_hq,
refiners_low_res_mask_hq,
atol=4e-3,
atol=1e-2,
)
assert ( # absolute diff in number of pixels
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 10
)
assert (
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 2
) # The diff on the logits above leads to an absolute diff of 2 pixel on the high res masks
assert torch.allclose(
iou_predictions_np,
torch.max(iou_predictions),
atol=1e-5,
atol=1e-4,
)

View file

@ -317,8 +317,9 @@ def test_predictor(
for i in range(3):
mask_prediction = masks[i].cpu()
facebook_mask = torch.as_tensor(facebook_masks[i])
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
iou = intersection_over_union(mask_prediction, facebook_mask)
assert isclose(iou, 1.0, rel_tol=5e-04), f"iou: {iou}"
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-04)
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:

View file

@ -1,25 +1,65 @@
from functools import cache
from pathlib import Path
from textwrap import dedent
import numpy as np
import piq # type: ignore
import torch
import torch.nn as nn
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
from refiners.conversion.models import dinov2
from refiners.fluxion.utils import image_to_tensor
from refiners.foundationals.dinov2 import DINOv2_small
def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[int, float]:
x1, x2 = (
torch.tensor(np.array(x).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0 for x in (img_1, img_2)
@cache
def get_small_dinov2_model() -> DINOv2_small:
model = DINOv2_small()
model.load_from_safetensors(
dinov2.small.converted.local_path
if dinov2.small.converted.local_path.exists()
else dinov2.small.converted.hf_cache_path
)
return (piq.psnr(x1, x2), piq.ssim(x1, x2).item()) # type: ignore
return model
def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int = 45, min_ssim: float = 0.99):
psnr, ssim = compare_images(img_1, img_2)
assert (psnr >= min_psnr) and (
ssim >= min_ssim
), f"PSNR {psnr} / SSIM {ssim}, expected at least {min_psnr} / {min_ssim}"
def compare_images(
img_1: Image.Image,
img_2: Image.Image,
) -> tuple[float, float, float]:
x1 = image_to_tensor(img_1)
x2 = image_to_tensor(img_2)
psnr = piq.psnr(x1, x2) # type: ignore
ssim = piq.ssim(x1, x2) # type: ignore
dinov2_model = get_small_dinov2_model()
dinov2 = torch.nn.functional.cosine_similarity(
dinov2_model(x1)[:, 0],
dinov2_model(x2)[:, 0],
)
return psnr.item(), ssim.item(), dinov2.item() # type: ignore
def ensure_similar_images(
img_1: Image.Image,
img_2: Image.Image,
min_psnr: int = 45,
min_ssim: float = 0.99,
min_dinov2: float = 0.99,
) -> None:
psnr, ssim, dinov2 = compare_images(img_1, img_2)
if (psnr < min_psnr) or (ssim < min_ssim) or (dinov2 < min_dinov2):
raise AssertionError(
dedent(f"""
Images are not similar enough!
- PSNR: {psnr:08.05f} (required at least {min_psnr:08.05f})
- SSIM: {ssim:08.06f} (required at least {min_ssim:08.06f})
- DINO: {dinov2:08.06f} (required at least {min_dinov2:08.06f})
""").strip()
)
class T5TextEmbedder(nn.Module):