mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 05:38:46 +00:00
add DINOv2 similarity to compare_images + relax some test constraints
This commit is contained in:
parent
d498ecd369
commit
ef0a7525a8
|
@ -832,7 +832,7 @@ def test_diffusion_std_random_init_bfloat16(
|
|||
)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16)
|
||||
ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16, min_psnr=30, min_ssim=0.97)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -1166,7 +1166,7 @@ def test_diffusion_inpainting_float16(
|
|||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
# PSNR and SSIM values are large because float16 is even worse than float32.
|
||||
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
|
||||
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95, min_dinov2=0.96)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -1245,7 +1245,7 @@ def test_diffusion_controlnet_tile_upscale(
|
|||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
# Note: rather large tolerances are used on purpose here (loose comparison with diffusers' output)
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75)
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=24, min_ssim=0.75, min_dinov2=0.94)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -1852,7 +1852,7 @@ def test_diffusion_ella_adapter(
|
|||
condition_scale=12,
|
||||
)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=35, min_ssim=0.98)
|
||||
ensure_similar_images(predicted_image, expected_image_ella_adapter, min_psnr=31, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -1937,7 +1937,7 @@ def test_diffusion_ip_adapter_multi(
|
|||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi)
|
||||
ensure_similar_images(predicted_image, expected_image_ip_adapter_multi, min_psnr=43, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -2130,7 +2130,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
|
|||
sdxl.lda.to(dtype=torch.float32)
|
||||
predicted_image = sdxl.lda.latents_to_image(x.to(dtype=torch.float32))
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
|
||||
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman, min_psnr=43, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -2608,11 +2608,11 @@ def test_style_aligned(
|
|||
|
||||
# tile all images horizontally
|
||||
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
|
||||
for i in range(len(predicted_images)):
|
||||
merged_image.paste(predicted_images[i], (i * 1024, 0)) # type: ignore
|
||||
for i, image in enumerate(predicted_images):
|
||||
merged_image.paste(image, (1024 * i, 0))
|
||||
|
||||
# compare against reference image
|
||||
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
|
||||
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=12, min_ssim=0.39, min_dinov2=0.95)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -2624,7 +2624,7 @@ def test_multi_upscaler(
|
|||
generator = torch.Generator(device=multi_upscaler.device)
|
||||
generator.manual_seed(37)
|
||||
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
|
||||
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)
|
||||
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=25, min_ssim=0.85, min_dinov2=0.96)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
|
|
@ -110,7 +110,7 @@ def test_guide_adapting_sdxl_vanilla(
|
|||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -152,7 +152,7 @@ def test_guide_adapting_sdxl_single_lora(
|
|||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -196,7 +196,7 @@ def test_guide_adapting_sdxl_multiple_loras(
|
|||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=38, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -256,7 +256,7 @@ def test_guide_adapting_sdxl_loras_ip_adapter(
|
|||
)
|
||||
|
||||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=29, min_ssim=0.98)
|
||||
|
||||
|
||||
# We do not (yet) test the last example using T2i-Adapter with Zoe Depth.
|
||||
|
|
|
@ -93,7 +93,7 @@ def test_lightning_base_4step(
|
|||
)
|
||||
predicted_image = sdxl.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -144,7 +144,7 @@ def test_lightning_base_1step(
|
|||
)
|
||||
predicted_image = sdxl.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)
|
||||
|
||||
|
||||
@no_grad()
|
||||
|
@ -198,4 +198,4 @@ def test_lightning_lora_4step(
|
|||
)
|
||||
predicted_image = sdxl.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=40, min_ssim=0.98)
|
||||
|
|
|
@ -238,15 +238,15 @@ def test_predictor(
|
|||
assert torch.allclose(
|
||||
reference_low_res_mask_hq,
|
||||
refiners_low_res_mask_hq,
|
||||
atol=4e-3,
|
||||
atol=1e-2,
|
||||
)
|
||||
assert ( # absolute diff in number of pixels
|
||||
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 10
|
||||
)
|
||||
assert (
|
||||
torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() <= 2
|
||||
) # The diff on the logits above leads to an absolute diff of 2 pixel on the high res masks
|
||||
assert torch.allclose(
|
||||
iou_predictions_np,
|
||||
torch.max(iou_predictions),
|
||||
atol=1e-5,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -317,8 +317,9 @@ def test_predictor(
|
|||
for i in range(3):
|
||||
mask_prediction = masks[i].cpu()
|
||||
facebook_mask = torch.as_tensor(facebook_masks[i])
|
||||
assert isclose(intersection_over_union(mask_prediction, facebook_mask), 1.0, rel_tol=5e-05)
|
||||
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
|
||||
iou = intersection_over_union(mask_prediction, facebook_mask)
|
||||
assert isclose(iou, 1.0, rel_tol=5e-04), f"iou: {iou}"
|
||||
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-04)
|
||||
|
||||
|
||||
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
|
||||
|
|
|
@ -1,25 +1,65 @@
|
|||
from functools import cache
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
import numpy as np
|
||||
import piq # type: ignore
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
|
||||
|
||||
from refiners.conversion.models import dinov2
|
||||
from refiners.fluxion.utils import image_to_tensor
|
||||
from refiners.foundationals.dinov2 import DINOv2_small
|
||||
|
||||
def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[int, float]:
|
||||
x1, x2 = (
|
||||
torch.tensor(np.array(x).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0 for x in (img_1, img_2)
|
||||
|
||||
@cache
|
||||
def get_small_dinov2_model() -> DINOv2_small:
|
||||
model = DINOv2_small()
|
||||
model.load_from_safetensors(
|
||||
dinov2.small.converted.local_path
|
||||
if dinov2.small.converted.local_path.exists()
|
||||
else dinov2.small.converted.hf_cache_path
|
||||
)
|
||||
return (piq.psnr(x1, x2), piq.ssim(x1, x2).item()) # type: ignore
|
||||
return model
|
||||
|
||||
|
||||
def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int = 45, min_ssim: float = 0.99):
|
||||
psnr, ssim = compare_images(img_1, img_2)
|
||||
assert (psnr >= min_psnr) and (
|
||||
ssim >= min_ssim
|
||||
), f"PSNR {psnr} / SSIM {ssim}, expected at least {min_psnr} / {min_ssim}"
|
||||
def compare_images(
|
||||
img_1: Image.Image,
|
||||
img_2: Image.Image,
|
||||
) -> tuple[float, float, float]:
|
||||
x1 = image_to_tensor(img_1)
|
||||
x2 = image_to_tensor(img_2)
|
||||
|
||||
psnr = piq.psnr(x1, x2) # type: ignore
|
||||
ssim = piq.ssim(x1, x2) # type: ignore
|
||||
|
||||
dinov2_model = get_small_dinov2_model()
|
||||
dinov2 = torch.nn.functional.cosine_similarity(
|
||||
dinov2_model(x1)[:, 0],
|
||||
dinov2_model(x2)[:, 0],
|
||||
)
|
||||
|
||||
return psnr.item(), ssim.item(), dinov2.item() # type: ignore
|
||||
|
||||
|
||||
def ensure_similar_images(
|
||||
img_1: Image.Image,
|
||||
img_2: Image.Image,
|
||||
min_psnr: int = 45,
|
||||
min_ssim: float = 0.99,
|
||||
min_dinov2: float = 0.99,
|
||||
) -> None:
|
||||
psnr, ssim, dinov2 = compare_images(img_1, img_2)
|
||||
if (psnr < min_psnr) or (ssim < min_ssim) or (dinov2 < min_dinov2):
|
||||
raise AssertionError(
|
||||
dedent(f"""
|
||||
Images are not similar enough!
|
||||
- PSNR: {psnr:08.05f} (required at least {min_psnr:08.05f})
|
||||
- SSIM: {ssim:08.06f} (required at least {min_ssim:08.06f})
|
||||
- DINO: {dinov2:08.06f} (required at least {min_dinov2:08.06f})
|
||||
""").strip()
|
||||
)
|
||||
|
||||
|
||||
class T5TextEmbedder(nn.Module):
|
||||
|
|
Loading…
Reference in a new issue