refiners/tests/utils.py

19 lines
670 B
Python
Raw Normal View History

2023-08-04 13:28:41 +00:00
import torch
import piq # type: ignore
import numpy as np
from PIL import Image
def compare_images(img_1: Image.Image, img_2: Image.Image) -> tuple[float, float]:
    """Compute the PSNR and SSIM between two images.

    Each image is turned into a float32 tensor in ``(1, C, H, W)`` layout with
    pixel values divided by 255 (so 8-bit images land in ``[0, 1]``; other bit
    depths are not rescaled accordingly — NOTE(review): assumes 8-bit input).
    Images whose numpy array is 2-D (e.g. mode ``"L"``) get a trailing channel
    axis so they can be compared as single-channel images.

    Args:
        img_1: First image.
        img_2: Second image; must have the same size and channel count as
            ``img_1`` for ``piq`` to accept the pair.

    Returns:
        A ``(psnr, ssim)`` tuple of Python floats as computed by ``piq.psnr``
        and ``piq.ssim``.
    """
    x1, x2 = (
        # atleast_3d leaves (H, W, C) arrays untouched and turns (H, W) into (H, W, 1),
        # so the HWC -> CHW permute below works for grayscale images too.
        torch.tensor(np.atleast_3d(np.array(img).astype(np.float32))).permute(2, 0, 1).unsqueeze(0) / 255.0
        for img in (img_1, img_2)
    )
    # `.item()` turns each 0-dim result tensor into a plain Python float, so both
    # metrics match the annotated return type and format readably in messages.
    return (piq.psnr(x1, x2).item(), piq.ssim(x1, x2).item())  # type: ignore
def ensure_similar_images(img_1: Image.Image, img_2: Image.Image, min_psnr: int = 45, min_ssim: float = 0.99):
    """Assert that two images are close according to PSNR and SSIM.

    The metrics come from ``compare_images``; both must reach their thresholds
    (inclusive) for the check to pass.

    Args:
        img_1: First image.
        img_2: Second image.
        min_psnr: Lowest acceptable PSNR.
        min_ssim: Lowest acceptable SSIM.

    Raises:
        AssertionError: If either metric falls below its threshold; the message
            reports the measured values next to the expected minimums.
    """
    measured_psnr, measured_ssim = compare_images(img_1, img_2)
    # Short-circuits like the original: SSIM is only compared when PSNR passes.
    is_similar = (measured_psnr >= min_psnr) and (measured_ssim >= min_ssim)
    assert is_similar, (
        f"PSNR {measured_psnr} / SSIM {measured_ssim}, expected at least {min_psnr} / {min_ssim}"
    )