# refiners/tests/utils.py

from functools import cache
from pathlib import Path
from textwrap import dedent
import piq # type: ignore
import torch
import torch.nn as nn
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
from refiners.conversion.models import dinov2
from refiners.fluxion.utils import image_to_tensor
from refiners.foundationals.dinov2 import DINOv2_small
@cache
def get_small_dinov2_model() -> DINOv2_small:
    """Return a process-wide cached DINOv2-small model with weights loaded.

    Prefers the locally converted safetensors checkpoint when it exists on
    disk, otherwise falls back to the copy in the HuggingFace cache.
    The `@cache` decorator ensures the model is built and loaded only once.
    """
    # NOTE: the original pasted source had a scrape-artifact timestamp line
    # inside this call; it has been removed to restore valid syntax.
    model = DINOv2_small()
    model.load_from_safetensors(
        dinov2.small.converted.local_path
        if dinov2.small.converted.local_path.exists()
        else dinov2.small.converted.hf_cache_path
    )
    return model
def compare_images(
    img_1: Image.Image,
    img_2: Image.Image,
) -> tuple[float, float, float]:
    """Compute (PSNR, SSIM, DINOv2 cosine similarity) between two images.

    The DINOv2 score is the cosine similarity between the CLS-token
    embeddings (index 0 along the token axis) of the two images.

    NOTE(review): piq.psnr/ssim assume a specific value range (default
    data_range=1.0) — image_to_tensor's output range is not visible here;
    confirm it matches.
    """
    x1 = image_to_tensor(img_1)
    x2 = image_to_tensor(img_2)
    psnr = piq.psnr(x1, x2)  # type: ignore
    ssim = piq.ssim(x1, x2)  # type: ignore
    dinov2_model = get_small_dinov2_model()
    # Renamed from `dinov2` to avoid shadowing the `dinov2` module imported
    # at file level (refiners.conversion.models.dinov2).
    dinov2_sim = torch.nn.functional.cosine_similarity(
        dinov2_model(x1)[:, 0],
        dinov2_model(x2)[:, 0],
    )
    return psnr.item(), ssim.item(), dinov2_sim.item()  # type: ignore
def ensure_similar_images(
    img_1: Image.Image,
    img_2: Image.Image,
    min_psnr: int = 45,
    min_ssim: float = 0.99,
    min_dinov2: float = 0.99,
) -> None:
    """Assert that two images are similar under PSNR, SSIM and DINOv2 metrics.

    Raises AssertionError with a per-metric report when any metric falls
    below its threshold; returns None when all thresholds are met.
    """
    psnr, ssim, dinov2 = compare_images(img_1, img_2)
    all_ok = psnr >= min_psnr and ssim >= min_ssim and dinov2 >= min_dinov2
    if all_ok:
        return
    report = dedent(f"""
        Images are not similar enough!
        - PSNR: {psnr:08.05f} (required at least {min_psnr:08.05f})
        - SSIM: {ssim:08.06f} (required at least {min_ssim:08.06f})
        - DINO: {dinov2:08.06f} (required at least {min_dinov2:08.06f})
    """).strip()
    raise AssertionError(report)
class T5TextEmbedder(nn.Module):
    """Thin wrapper around a pretrained T5 encoder for text embedding.

    Loads a `T5EncoderModel` and matching `T5Tokenizer` from a pretrained
    path, then exposes a `forward` that turns a caption string (or
    pre-tokenized ids + mask) into the encoder's last hidden state.
    """

    def __init__(
        self,
        pretrained_path: Path | str,
        max_length: int | None = None,
        local_files_only: bool = False,
    ) -> None:
        """Load the encoder and tokenizer from `pretrained_path`.

        Args:
            pretrained_path: HF model id or local directory.
            max_length: default tokenization length; None means no padding
                or truncation unless overridden per call.
            local_files_only: forwarded to `from_pretrained` to skip network.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.model: nn.Module = T5EncoderModel.from_pretrained(  # type: ignore
            pretrained_path,
            local_files_only=local_files_only,
        )
        # BUGFIX: the annotation previously referenced `transformers.T5Tokenizer`,
        # but only the class name `T5Tokenizer` is imported — annotations on
        # attribute targets are evaluated at runtime (PEP 526), so the old
        # annotation raised NameError.
        self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(  # type: ignore
            pretrained_path,
            local_files_only=local_files_only,
        )
        self.max_length = max_length

    def forward(
        self,
        caption: str,
        text_input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        max_length: int | None = None,
    ) -> torch.Tensor:
        """Embed `caption` (or the given token ids/mask) with the T5 encoder.

        Args:
            caption: text to tokenize when ids/mask are not supplied.
            text_input_ids: optional pre-computed token ids.
            attention_mask: optional pre-computed attention mask; both it and
                `text_input_ids` must be given to skip tokenization.
            max_length: overrides the instance default for this call; when a
                length is available, inputs are padded/truncated to it.

        Returns:
            The encoder's last hidden state tensor.
        """
        if max_length is None:
            max_length = self.max_length

        if text_input_ids is None or attention_mask is None:
            if max_length is not None:
                text_inputs = self.tokenizer(  # type: ignore
                    caption,
                    return_tensors="pt",
                    add_special_tokens=True,
                    max_length=max_length,
                    padding="max_length",
                    truncation=True,
                )
            else:
                text_inputs = self.tokenizer(caption, return_tensors="pt", add_special_tokens=True)  # type: ignore
            _text_input_ids: torch.Tensor = text_inputs.input_ids.to(self.model.device)  # type: ignore
            _attention_mask: torch.Tensor = text_inputs.attention_mask.to(self.model.device)  # type: ignore
        else:
            _text_input_ids: torch.Tensor = text_input_ids.to(self.model.device)  # type: ignore
            _attention_mask: torch.Tensor = attention_mask.to(self.model.device)  # type: ignore

        outputs = self.model(_text_input_ids, attention_mask=_attention_mask)
        embeddings = outputs.last_hidden_state
        return embeddings