refiners/tests/utils.py
2024-10-14 15:12:59 +02:00

75 lines
2.8 KiB
Python

from pathlib import Path
import numpy as np
import piq # type: ignore
import torch
import torch.nn as nn
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[float, float]:
    """Return ``(psnr, ssim)`` between two images as plain Python floats.

    Each image is converted to a ``(1, C, H, W)`` float32 tensor and divided by
    255. This assumes 8-bit channels, giving values in [0, 1]. The two images
    are expected to share size and channel count, since piq compares the
    tensors element-wise.

    Single-channel images (``np.array`` yields a 2-D array) get an explicit
    channel axis, so grayscale inputs work as well as RGB/RGBA ones.

    Args:
        img_1: Reference image.
        img_2: Image to compare against ``img_1``.

    Returns:
        A ``(psnr, ssim)`` tuple, both as Python floats.
    """

    def to_batch(img: Image.Image) -> torch.Tensor:
        # Build a single-image NCHW batch scaled by 1/255.
        pixels = np.array(img).astype(np.float32)
        if pixels.ndim == 2:
            # Grayscale images have no channel axis; add one so the CHW permute works.
            pixels = pixels[:, :, np.newaxis]
        return torch.tensor(pixels).permute(2, 0, 1).unsqueeze(0) / 255.0

    x1, x2 = to_batch(img_1), to_batch(img_2)
    # Both metrics come back as 0-d tensors; .item() turns them into floats so the
    # return value matches the annotation and prints cleanly.
    return piq.psnr(x1, x2).item(), piq.ssim(x1, x2).item()  # type: ignore
def ensure_similar_images(
    img_1: Image.Image,
    img_2: Image.Image,
    min_psnr: float = 45,
    min_ssim: float = 0.99,
) -> None:
    """Assert that two images are near-identical according to PSNR and SSIM.

    Args:
        img_1: Reference image.
        img_2: Image to compare against ``img_1``.
        min_psnr: Minimum acceptable PSNR (inclusive).
        min_ssim: Minimum acceptable SSIM (inclusive).

    Raises:
        AssertionError: If either metric falls below its threshold. The message
            reports both measured values and both thresholds.
    """
    psnr, ssim = compare_images(img_1, img_2)
    # Coerce to float so the comparison and the failure message use plain numbers,
    # even if a metric is returned as a 0-d tensor (which would print as "tensor(...)").
    psnr, ssim = float(psnr), float(ssim)
    assert (psnr >= min_psnr) and (
        ssim >= min_ssim
    ), f"PSNR {psnr} / SSIM {ssim}, expected at least {min_psnr} / {min_ssim}"
class T5TextEmbedder(nn.Module):
    """Encode captions with a pretrained T5 encoder.

    Loads a ``T5EncoderModel`` and its matching ``T5Tokenizer`` from the same
    checkpoint. ``forward`` returns the encoder's ``last_hidden_state``.
    """

    def __init__(
        self,
        pretrained_path: Path | str,
        max_length: int | None = None,
        local_files_only: bool = False,
    ) -> None:
        """Load the encoder and tokenizer.

        Args:
            pretrained_path: Checkpoint identifier or directory passed to both
                ``from_pretrained`` calls.
            max_length: Default tokenization length used by ``forward`` when no
                per-call value is given. ``None`` means no padding/truncation.
            local_files_only: Forwarded unchanged to ``from_pretrained``.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.model: nn.Module = T5EncoderModel.from_pretrained(  # type: ignore
            pretrained_path,
            local_files_only=local_files_only,
        )
        # Annotate with the imported class name; the `transformers` module itself
        # is not imported in this file.
        self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(  # type: ignore
            pretrained_path,
            local_files_only=local_files_only,
        )
        self.max_length = max_length

    def _tokenize(self, caption: str, max_length: int | None) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize ``caption`` and return ``(input_ids, attention_mask)`` tensors.

        When ``max_length`` is given, the sequence is padded to exactly that length
        and truncated if longer. Otherwise it is tokenized at its natural length.
        """
        length_kwargs: dict[str, int | str | bool] = (
            {}
            if max_length is None
            else {"max_length": max_length, "padding": "max_length", "truncation": True}
        )
        text_inputs = self.tokenizer(  # type: ignore
            caption,
            return_tensors="pt",
            add_special_tokens=True,
            **length_kwargs,
        )
        return text_inputs.input_ids, text_inputs.attention_mask  # type: ignore

    def forward(
        self,
        caption: str,
        text_input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        max_length: int | None = None,
    ) -> torch.Tensor:
        """Return the T5 encoder's last hidden state for a caption.

        Args:
            caption: Text to encode. Ignored when both ``text_input_ids`` and
                ``attention_mask`` are supplied.
            text_input_ids: Pre-tokenized ids. Only used together with
                ``attention_mask``; if just one of the two is given, the
                caption is tokenized instead and the supplied tensor is ignored.
            attention_mask: Mask matching ``text_input_ids``.
            max_length: Per-call override for the tokenization length. Falls
                back to the value given at construction.

        Returns:
            The encoder output's ``last_hidden_state``.
        """
        if max_length is None:
            max_length = self.max_length
        if text_input_ids is None or attention_mask is None:
            text_input_ids, attention_mask = self._tokenize(caption, max_length)
        # Move inputs to wherever the model's weights live, whether they were
        # tokenized here or provided by the caller.
        device = self.model.device  # type: ignore
        outputs = self.model(text_input_ids.to(device), attention_mask=attention_mask.to(device))
        return outputs.last_hidden_state