refiners/tests/e2e/test_solutions.py
Pierre Chapuis 7ca1774b5f
Some checks failed
CI / lint_and_typecheck (push) Has been cancelled
Deploy docs to GitHub Pages / Deploy docs (push) Has been cancelled
Spell checker / Spell check (push) Has been cancelled
add box segmenter solution
2024-08-30 09:26:53 +02:00

73 lines
2.5 KiB
Python

from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
from tests.utils import ensure_similar_images
from refiners.solutions import BoxSegmenter
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_solutions_ref"
@pytest.fixture(scope="module")
def ref_shelves(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "shelves.jpg").convert("RGB")
@pytest.fixture
def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_plant_mask.png")
@pytest.fixture
def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_spray_mask.png")
@pytest.fixture
def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
@pytest.fixture(scope="module")
def box_segmenter_weights(test_weights_path: Path) -> Path:
weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors"
if not weights.is_file():
warn(f"could not find weights at {test_weights_path}, skipping")
pytest.skip(allow_module_level=True)
return weights
def test_box_segmenter(
box_segmenter_weights: Path,
ref_shelves: Image.Image,
expected_box_segmenter_plant_mask: Image.Image,
expected_box_segmenter_spray_mask: Image.Image,
expected_box_segmenter_spray_cropped_mask: Image.Image,
test_device: torch.device,
):
segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device)
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
spray_box = (461, 542, 594, 823)
spray_mask = segmenter(ref_shelves, box_prompt=spray_box)
ensure_similar_images(spray_mask.convert("RGB"), expected_box_segmenter_spray_mask.convert("RGB"))
# Test left and bottom padding.
off_l, off_b = 11, 7
shelves_cropped = ref_shelves.crop((spray_box[0] - off_l, 0, ref_shelves.width, spray_box[3] + off_b))
spray_cropped_box = (off_l, spray_box[1], spray_box[2] - spray_box[0] + off_l, spray_box[3])
spray_cropped_mask = segmenter(shelves_cropped, box_prompt=spray_cropped_box)
ensure_similar_images(spray_cropped_mask.convert("RGB"), expected_box_segmenter_spray_cropped_mask.convert("RGB"))