2024-08-10 13:17:36 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from warnings import warn
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
from PIL import Image
|
|
|
|
from tests.utils import ensure_similar_images
|
|
|
|
|
|
|
|
from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image
|
|
|
|
from refiners.foundationals.swin.mvanet import MVANet
|
|
|
|
|
|
|
|
|
|
|
|
def _img_open(path: Path) -> Image.Image:
    """Open the image file at *path* with PIL and return it unchanged.

    Centralizes the single ``Image.open`` call so the type-checker suppression
    (presumably needed because the project's checker rejects PIL's typing of
    ``Image.open`` — TODO confirm) lives in exactly one place.
    """
    return Image.open(path)  # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
    """Directory containing the MVANet reference assets.

    The other fixtures in this module read ``cactus.png`` and
    ``expected_cactus_mask.png`` from it.
    """
    return test_e2e_path.joinpath("test_mvanet_ref")
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def ref_cactus(ref_path: Path) -> Image.Image:
    """Reference input image (``cactus.png``), converted to 3-channel RGB."""
    cactus_file = ref_path / "cactus.png"
    raw_image = _img_open(cactus_file)
    return raw_image.convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def expected_cactus_mask(ref_path: Path) -> Image.Image:
    """Expected segmentation mask for ``cactus.png``.

    The mask is returned in whatever mode it was stored in; no mode conversion
    is applied here (the tests convert to RGB before comparing).
    """
    mask_file = ref_path.joinpath("expected_cactus_mask.png")
    return _img_open(mask_file)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def mvanet_weights(test_weights_path: Path) -> Path:
    """Return the path to the MVANet safetensors checkpoint.

    The checkpoint is expected at
    ``<test_weights_path>/mvanet/mvanet.safetensors``. When that file is
    missing, a warning is emitted and every test depending on this fixture is
    skipped.
    """
    weights = test_weights_path / "mvanet" / "mvanet.safetensors"
    if not weights.is_file():
        # Report the exact file that was looked up, not just the weights root,
        # so a missing or misplaced checkpoint is easy to diagnose.
        warn(f"could not find weights at {weights}, skipping")
        pytest.skip(f"could not find weights at {weights}", allow_module_level=True)
    return weights
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mvanet_model(mvanet_weights: Path, test_device: torch.device) -> MVANet:
    """Build an MVANet on ``test_device`` in eval mode and load the checkpoint into it."""
    net = MVANet(device=test_device)
    # Eval mode is required for the reference comparison. Presumably this is
    # because some layers behave differently in training mode — TODO confirm.
    net = net.eval()
    net.load_from_safetensors(mvanet_weights)
    return net
|
|
|
|
|
|
|
|
|
|
|
|
@no_grad()
def test_mvanet(
    mvanet_model: MVANet,
    ref_cactus: Image.Image,
    expected_cactus_mask: Image.Image,
    test_device: torch.device,
):
    """Run MVANet on the cactus image and compare its mask to the stored reference."""
    # Preprocess: bilinear resize to 1024x1024, convert to a tensor, then
    # normalize with the standard ImageNet mean/std.
    resized = ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)
    unbatched = image_to_tensor(resized).squeeze()
    channel_mean, channel_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    batch = normalize(unbatched, channel_mean, channel_std).unsqueeze(0)

    # Inference: the sigmoid of the model output is used as the mask.
    logits = mvanet_model(batch.to(test_device))
    probabilities: torch.Tensor = logits.sigmoid()

    # Postprocess: back to a PIL image at the original input resolution.
    mask = tensor_to_image(probabilities).resize(ref_cactus.size, Image.Resampling.BILINEAR)

    actual_rgb = mask.convert("RGB")
    expected_rgb = expected_cactus_mask.convert("RGB")
    ensure_similar_images(actual_rgb, expected_rgb)
|
2024-08-27 16:02:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
@no_grad()
def test_mvanet_to(
    mvanet_weights: Path,
    ref_cactus: Image.Image,
    expected_cactus_mask: Image.Image,
    test_device: torch.device,
):
    """Check that a CPU-loaded MVANet still matches the reference after ``.to(test_device)``.

    The model is built and loaded on CPU, then moved to the test device.
    Inference there must reproduce the reference mask. The test is skipped when
    the test device is itself the CPU, because the move would then be a no-op.
    """
    if test_device.type == "cpu":
        # Skip when the test device IS the CPU: there is nothing to move to.
        warn("test device is CPU, cannot test moving the model off CPU, skipping")
        pytest.skip("requires a non-CPU test device")

    model = MVANet(device=torch.device("cpu")).eval()
    model.load_from_safetensors(mvanet_weights)
    model.to(test_device)  # the operation under test

    in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
    in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
    prediction: torch.Tensor = model(in_t.to(test_device)).sigmoid()
    cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
    ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))
|