diff --git a/README.md b/README.md
index fd19fcc..e27a87f 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,8 @@ ______________________________________________________________________
 
 ## Latest News 🔥
 
+- Added the Box Segmenter all-in-one solution ([model](https://huggingface.co/finegrain/finegrain-box-segmenter), [HF Space](https://huggingface.co/spaces/finegrain/finegrain-object-cutter))
+- Added [MVANet](https://arxiv.org/abs/2404.07445) for high-resolution segmentation
 - Added [IC-Light](https://github.com/lllyasviel/IC-Light) to manipulate the illumination of images
 - Added Multi Upscaler for high-resolution image generation, inspired by [Clarity Upscaler](https://github.com/philz1337x/clarity-upscaler) ([HF Space](https://huggingface.co/spaces/finegrain/enhancer))
 - Added [HQ-SAM](https://arxiv.org/abs/2306.01567) for high quality mask prediction with Segment Anything
diff --git a/pyproject.toml b/pyproject.toml
index ee6909f..b27324b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,6 +67,9 @@ doc = [
     "mkdocstrings[python]>=0.24.0",
     "mkdocs-literate-nav>=0.6.1",
 ]
+solutions = [
+    "huggingface-hub>=0.24.6",
+]
 
 [build-system]
 requires = ["hatchling"]
diff --git a/requirements.lock b/requirements.lock
index f9ec260..e1c7f47 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -100,9 +100,10 @@ gitpython==3.1.43
     # via wandb
 griffe==0.48.0
     # via mkdocstrings-python
-huggingface-hub==0.24.5
+huggingface-hub==0.24.6
     # via datasets
     # via diffusers
+    # via refiners
     # via timm
     # via tokenizers
     # via transformers
diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index 6f27a67..e09f477 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -466,6 +466,15 @@ def download_mvanet():
     check_hash(dest_filename, "b915d492")
 
 
+def download_box_segmenter():
+    download_file(
+        "https://huggingface.co/finegrain/finegrain-box-segmenter/resolve/v0.1/model.safetensors",
+        dest_folder=test_weights_dir,
+        filename="finegrain-box-segmenter-v0-1.safetensors",
+        expected_hash="e0450e8c",
+    )
+
+
 def printg(msg: str):
     """print in green color"""
     print("\033[92m" + msg + "\033[0m")
@@ -861,6 +870,7 @@ def download_all():
     download_sdxl_lightning_lora()
     download_ic_light()
     download_mvanet()
+    download_box_segmenter()
 
 
 def convert_all():
diff --git a/src/refiners/solutions/__init__.py b/src/refiners/solutions/__init__.py
new file mode 100644
index 0000000..e88d8ad
--- /dev/null
+++ b/src/refiners/solutions/__init__.py
@@ -0,0 +1,3 @@
+from .box_segmenter import BoxSegmenter
+
+__all__ = ["BoxSegmenter"]
diff --git a/src/refiners/solutions/box_segmenter.py b/src/refiners/solutions/box_segmenter.py
new file mode 100644
index 0000000..26a7b18
--- /dev/null
+++ b/src/refiners/solutions/box_segmenter.py
@@ -0,0 +1,79 @@
+from pathlib import Path
+
+import torch
+from PIL import Image
+
+from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image
+from refiners.foundationals.swin.mvanet import MVANet
+
+BoundingBox = tuple[int, int, int, int]
+
+
+class BoxSegmenter:
+    def __init__(
+        self,
+        *,
+        margin: float = 0.05,
+        weights: Path | str | dict[str, torch.Tensor] | None = None,
+        device: torch.device | str = "cpu",
+    ):
+        assert margin >= 0
+        self.margin = margin
+
+        self.device = torch.device(device)
+        self.model = MVANet(device=self.device).eval()
+
+        if weights is None:
+            from huggingface_hub.file_download import hf_hub_download  # type: ignore[reportUnknownVariableType]
+
+            weights = hf_hub_download(
repo_id="finegrain/finegrain-box-segmenter", + filename="model.safetensors", + revision="v0.1", + ) + + if isinstance(weights, dict): + self.model.load_state_dict(weights) + else: + self.model.load_from_safetensors(weights) + + def __call__(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image: + return self.run(img, box_prompt) + + def add_margin(self, box: BoundingBox) -> BoundingBox: + x0, y0, x1, y1 = box + mx = int((x1 - x0) * self.margin) + my = int((y1 - y0) * self.margin) + return (x0 - mx, y0 - my, x1 + mx, y1 + my) + + @staticmethod + def crop_pad(img: Image.Image, box: BoundingBox) -> Image.Image: + img = img.convert("RGB") + + x0, y0, x1, y1 = box + px0, py0, px1, py1 = (max(0, -x0), max(0, -y0), max(0, x1 - img.width), max(0, y1 - img.height)) + if (px0, py0, px1, py1) == (0, 0, 0, 0): + return img.crop(box) + + padded = Image.new("RGB", (img.width + px0 + px1, img.height + py0 + py1)) + padded.paste(img, (px0, py0)) + return padded.crop((x0 + px0, y0 + py0, x1 + px0, y1 + py0)) + + def predict(self, img: Image.Image) -> Image.Image: + in_t = image_to_tensor(img.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze() + in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0) + with no_grad(): + prediction: torch.Tensor = self.model(in_t.to(self.device)).sigmoid() + return tensor_to_image(prediction).resize(img.size, Image.Resampling.BILINEAR) + + def run(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image: + if box_prompt is None: + box_prompt = (0, 0, img.width, img.height) + + box = self.add_margin(box_prompt) + cropped = self.crop_pad(img, box) + prediction = self.predict(cropped) + + out = Image.new("L", (img.width, img.height)) + out.paste(prediction, box) + return out diff --git a/tests/e2e/test_solutions.py b/tests/e2e/test_solutions.py new file mode 100644 index 0000000..b9a0769 --- /dev/null +++ b/tests/e2e/test_solutions.py @@ -0,0 +1,72 @@ +from pathlib import Path +from warnings import warn + +import pytest +import torch +from PIL import Image +from tests.utils import ensure_similar_images + +from refiners.solutions import BoxSegmenter + + +def _img_open(path: Path) -> Image.Image: + return Image.open(path) # type: ignore + + +@pytest.fixture(scope="module") +def ref_path(test_e2e_path: Path) -> Path: + return test_e2e_path / "test_solutions_ref" + + +@pytest.fixture(scope="module") +def ref_shelves(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "shelves.jpg").convert("RGB") + + +@pytest.fixture +def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "expected_box_segmenter_plant_mask.png") + + +@pytest.fixture +def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "expected_box_segmenter_spray_mask.png") + + +@pytest.fixture +def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png") + + +@pytest.fixture(scope="module") +def box_segmenter_weights(test_weights_path: Path) -> Path: + weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors" + if not weights.is_file(): + warn(f"could not find weights at {test_weights_path}, skipping") + pytest.skip(allow_module_level=True) + return weights + + +def test_box_segmenter( + box_segmenter_weights: Path, + ref_shelves: Image.Image, + expected_box_segmenter_plant_mask: Image.Image, + expected_box_segmenter_spray_mask: Image.Image, + 
+    expected_box_segmenter_spray_cropped_mask: Image.Image,
+    test_device: torch.device,
+):
+    segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device)
+
+    plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
+    ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
+
+    spray_box = (461, 542, 594, 823)
+    spray_mask = segmenter(ref_shelves, box_prompt=spray_box)
+    ensure_similar_images(spray_mask.convert("RGB"), expected_box_segmenter_spray_mask.convert("RGB"))
+
+    # Test left and bottom padding.
+    off_l, off_b = 11, 7
+    shelves_cropped = ref_shelves.crop((spray_box[0] - off_l, 0, ref_shelves.width, spray_box[3] + off_b))
+    spray_cropped_box = (off_l, spray_box[1], spray_box[2] - spray_box[0] + off_l, spray_box[3])
+    spray_cropped_mask = segmenter(shelves_cropped, box_prompt=spray_cropped_box)
+    ensure_similar_images(spray_cropped_mask.convert("RGB"), expected_box_segmenter_spray_cropped_mask.convert("RGB"))
diff --git a/tests/e2e/test_solutions_ref/README.md b/tests/e2e/test_solutions_ref/README.md
new file mode 100644
index 0000000..231c11d
--- /dev/null
+++ b/tests/e2e/test_solutions_ref/README.md
@@ -0,0 +1,3 @@
+`shelves.jpg` is found here: https://www.freepik.com/free-photo/front-view-shelves-with-plants_6446859.htm
+
+`expected_box_segmenter_plant_mask.png`, `expected_box_segmenter_spray_mask.png` and `expected_box_segmenter_spray_cropped_mask.png` are generated with Refiners.
diff --git a/tests/e2e/test_solutions_ref/expected_box_segmenter_plant_mask.png b/tests/e2e/test_solutions_ref/expected_box_segmenter_plant_mask.png
new file mode 100644
index 0000000..cd795dd
Binary files /dev/null and b/tests/e2e/test_solutions_ref/expected_box_segmenter_plant_mask.png differ
diff --git a/tests/e2e/test_solutions_ref/expected_box_segmenter_spray_cropped_mask.png b/tests/e2e/test_solutions_ref/expected_box_segmenter_spray_cropped_mask.png
new file mode 100644
index 0000000..fe7e0c2
Binary files /dev/null and b/tests/e2e/test_solutions_ref/expected_box_segmenter_spray_cropped_mask.png differ
diff --git a/tests/e2e/test_solutions_ref/expected_box_segmenter_spray_mask.png b/tests/e2e/test_solutions_ref/expected_box_segmenter_spray_mask.png
new file mode 100644
index 0000000..ccaeddf
Binary files /dev/null and b/tests/e2e/test_solutions_ref/expected_box_segmenter_spray_mask.png differ
diff --git a/tests/e2e/test_solutions_ref/shelves.jpg b/tests/e2e/test_solutions_ref/shelves.jpg
new file mode 100644
index 0000000..c3ad71f
Binary files /dev/null and b/tests/e2e/test_solutions_ref/shelves.jpg differ
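
Below is a minimal usage sketch of the new `BoxSegmenter` solution. The box coordinates mirror the e2e test; the input and output file names and the CUDA device are illustrative assumptions. As implemented above, leaving `weights` unset downloads `model.safetensors` (revision `v0.1`) from `finegrain/finegrain-box-segmenter` on the Hugging Face Hub.

```python
from PIL import Image

from refiners.solutions import BoxSegmenter

# weights=None (the default) fetches the checkpoint from the Hub;
# device defaults to "cpu", "cuda" here is an assumption.
segmenter = BoxSegmenter(device="cuda")

image = Image.open("shelves.jpg")  # hypothetical local copy of the test image

# box_prompt is an (x0, y0, x1, y1) pixel box; it is grown by `margin`
# (5% per side by default) before cropping, then MVANet predicts the mask.
mask = segmenter(image, box_prompt=(504, 82, 754, 368))

# The result is a single-channel ("L") soft mask the same size as the input.
mask.save("plant_mask.png")
```

Omitting `box_prompt` segments the whole image, since `run` falls back to `(0, 0, img.width, img.height)`; boxes that extend past the image edges are handled by `crop_pad`, which pads the crop rather than clipping it.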