mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
73 lines
2.5 KiB
Python
73 lines
2.5 KiB
Python
|
from pathlib import Path
|
||
|
from warnings import warn
|
||
|
|
||
|
import pytest
|
||
|
import torch
|
||
|
from PIL import Image
|
||
|
from tests.utils import ensure_similar_images
|
||
|
|
||
|
from refiners.solutions import BoxSegmenter
|
||
|
|
||
|
|
||
|
def _img_open(path: Path) -> Image.Image:
|
||
|
return Image.open(path) # type: ignore
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
|
||
|
def ref_path(test_e2e_path: Path) -> Path:
|
||
|
return test_e2e_path / "test_solutions_ref"
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
|
||
|
def ref_shelves(ref_path: Path) -> Image.Image:
|
||
|
return _img_open(ref_path / "shelves.jpg").convert("RGB")
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image:
|
||
|
return _img_open(ref_path / "expected_box_segmenter_plant_mask.png")
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image:
|
||
|
return _img_open(ref_path / "expected_box_segmenter_spray_mask.png")
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
|
||
|
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
|
||
|
def box_segmenter_weights(test_weights_path: Path) -> Path:
|
||
|
weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors"
|
||
|
if not weights.is_file():
|
||
|
warn(f"could not find weights at {test_weights_path}, skipping")
|
||
|
pytest.skip(allow_module_level=True)
|
||
|
return weights
|
||
|
|
||
|
|
||
|
def test_box_segmenter(
|
||
|
box_segmenter_weights: Path,
|
||
|
ref_shelves: Image.Image,
|
||
|
expected_box_segmenter_plant_mask: Image.Image,
|
||
|
expected_box_segmenter_spray_mask: Image.Image,
|
||
|
expected_box_segmenter_spray_cropped_mask: Image.Image,
|
||
|
test_device: torch.device,
|
||
|
):
|
||
|
segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device)
|
||
|
|
||
|
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
|
||
|
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
|
||
|
|
||
|
spray_box = (461, 542, 594, 823)
|
||
|
spray_mask = segmenter(ref_shelves, box_prompt=spray_box)
|
||
|
ensure_similar_images(spray_mask.convert("RGB"), expected_box_segmenter_spray_mask.convert("RGB"))
|
||
|
|
||
|
# Test left and bottom padding.
|
||
|
off_l, off_b = 11, 7
|
||
|
shelves_cropped = ref_shelves.crop((spray_box[0] - off_l, 0, ref_shelves.width, spray_box[3] + off_b))
|
||
|
spray_cropped_box = (off_l, spray_box[1], spray_box[2] - spray_box[0] + off_l, spray_box[3])
|
||
|
spray_cropped_mask = segmenter(shelves_cropped, box_prompt=spray_cropped_box)
|
||
|
ensure_similar_images(spray_cropped_mask.convert("RGB"), expected_box_segmenter_spray_cropped_mask.convert("RGB"))
|