add box segmenter solution
Some checks failed
CI / lint_and_typecheck (push) Has been cancelled
Deploy docs to GitHub Pages / Deploy docs (push) Has been cancelled
Spell checker / Spell check (push) Has been cancelled

This commit is contained in:
Pierre Chapuis 2024-08-27 11:44:44 +02:00
parent c8ff6d95f0
commit 7ca1774b5f
12 changed files with 174 additions and 1 deletions

View file

@ -21,6 +21,8 @@ ______________________________________________________________________
## Latest News 🔥
- Added the Box Segmenter all-in-one solution ([model](https://huggingface.co/finegrain/finegrain-box-segmenter), [HF Space](https://huggingface.co/spaces/finegrain/finegrain-object-cutter))
- Added [MVANet](https://arxiv.org/abs/2404.07445) for high resolution segmentation
- Added [IC-Light](https://github.com/lllyasviel/IC-Light) to manipulate the illumination of images
- Added Multi Upscaler for high-resolution image generation, inspired from [Clarity Upscaler](https://github.com/philz1337x/clarity-upscaler) ([HF Space](https://huggingface.co/spaces/finegrain/enhancer))
- Added [HQ-SAM](https://arxiv.org/abs/2306.01567) for high quality mask prediction with Segment Anything

View file

@ -67,6 +67,9 @@ doc = [
"mkdocstrings[python]>=0.24.0",
"mkdocs-literate-nav>=0.6.1",
]
solutions = [
"huggingface-hub>=0.24.6",
]
[build-system]
requires = ["hatchling"]

View file

@ -100,9 +100,10 @@ gitpython==3.1.43
# via wandb
griffe==0.48.0
# via mkdocstrings-python
huggingface-hub==0.24.5
huggingface-hub==0.24.6
# via datasets
# via diffusers
# via refiners
# via timm
# via tokenizers
# via transformers

View file

@ -466,6 +466,15 @@ def download_mvanet():
check_hash(dest_filename, "b915d492")
def download_box_segmenter():
download_file(
"https://huggingface.co/finegrain/finegrain-box-segmenter/resolve/v0.1/model.safetensors",
dest_folder=test_weights_dir,
filename="finegrain-box-segmenter-v0-1.safetensors",
expected_hash="e0450e8c",
)
def printg(msg: str):
"""print in green color"""
print("\033[92m" + msg + "\033[0m")
@ -861,6 +870,7 @@ def download_all():
download_sdxl_lightning_lora()
download_ic_light()
download_mvanet()
download_box_segmenter()
def convert_all():

View file

@ -0,0 +1,3 @@
from .box_segmenter import BoxSegmenter
__all__ = ["BoxSegmenter"]

View file

@ -0,0 +1,79 @@
from pathlib import Path
import torch
from PIL import Image
from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image
from refiners.foundationals.swin.mvanet import MVANet
BoundingBox = tuple[int, int, int, int]
class BoxSegmenter:
def __init__(
self,
*,
margin: float = 0.05,
weights: Path | str | dict[str, torch.Tensor] | None = None,
device: torch.device | str = "cpu",
):
assert margin >= 0
self.margin = margin
self.device = torch.device(device)
self.model = MVANet(device=self.device).eval()
if weights is None:
from huggingface_hub.file_download import hf_hub_download # type: ignore[reportUnknownVariableType]
weights = hf_hub_download(
repo_id="finegrain/finegrain-box-segmenter",
filename="model.safetensors",
revision="v0.1",
)
if isinstance(weights, dict):
self.model.load_state_dict(weights)
else:
self.model.load_from_safetensors(weights)
def __call__(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image:
return self.run(img, box_prompt)
def add_margin(self, box: BoundingBox) -> BoundingBox:
x0, y0, x1, y1 = box
mx = int((x1 - x0) * self.margin)
my = int((y1 - y0) * self.margin)
return (x0 - mx, y0 - my, x1 + mx, y1 + my)
@staticmethod
def crop_pad(img: Image.Image, box: BoundingBox) -> Image.Image:
img = img.convert("RGB")
x0, y0, x1, y1 = box
px0, py0, px1, py1 = (max(0, -x0), max(0, -y0), max(0, x1 - img.width), max(0, y1 - img.height))
if (px0, py0, px1, py1) == (0, 0, 0, 0):
return img.crop(box)
padded = Image.new("RGB", (img.width + px0 + px1, img.height + py0 + py1))
padded.paste(img, (px0, py0))
return padded.crop((x0 + px0, y0 + py0, x1 + px0, y1 + py0))
def predict(self, img: Image.Image) -> Image.Image:
in_t = image_to_tensor(img.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
with no_grad():
prediction: torch.Tensor = self.model(in_t.to(self.device)).sigmoid()
return tensor_to_image(prediction).resize(img.size, Image.Resampling.BILINEAR)
def run(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image:
if box_prompt is None:
box_prompt = (0, 0, img.width, img.height)
box = self.add_margin(box_prompt)
cropped = self.crop_pad(img, box)
prediction = self.predict(cropped)
out = Image.new("L", (img.width, img.height))
out.paste(prediction, box)
return out

View file

@ -0,0 +1,72 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
from tests.utils import ensure_similar_images
from refiners.solutions import BoxSegmenter
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_solutions_ref"
@pytest.fixture(scope="module")
def ref_shelves(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "shelves.jpg").convert("RGB")
@pytest.fixture
def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_plant_mask.png")
@pytest.fixture
def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_spray_mask.png")
@pytest.fixture
def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
@pytest.fixture(scope="module")
def box_segmenter_weights(test_weights_path: Path) -> Path:
weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors"
if not weights.is_file():
warn(f"could not find weights at {test_weights_path}, skipping")
pytest.skip(allow_module_level=True)
return weights
def test_box_segmenter(
box_segmenter_weights: Path,
ref_shelves: Image.Image,
expected_box_segmenter_plant_mask: Image.Image,
expected_box_segmenter_spray_mask: Image.Image,
expected_box_segmenter_spray_cropped_mask: Image.Image,
test_device: torch.device,
):
segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device)
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
spray_box = (461, 542, 594, 823)
spray_mask = segmenter(ref_shelves, box_prompt=spray_box)
ensure_similar_images(spray_mask.convert("RGB"), expected_box_segmenter_spray_mask.convert("RGB"))
# Test left and bottom padding.
off_l, off_b = 11, 7
shelves_cropped = ref_shelves.crop((spray_box[0] - off_l, 0, ref_shelves.width, spray_box[3] + off_b))
spray_cropped_box = (off_l, spray_box[1], spray_box[2] - spray_box[0] + off_l, spray_box[3])
spray_cropped_mask = segmenter(shelves_cropped, box_prompt=spray_cropped_box)
ensure_similar_images(spray_cropped_mask.convert("RGB"), expected_box_segmenter_spray_cropped_mask.convert("RGB"))

View file

@ -0,0 +1,3 @@
`shelves.jpg` is found here: https://www.freepik.com/free-photo/front-view-shelves-with-plants_6446859.htm
`expected_box_segmenter_plant_mask.png`, `expected_box_segmenter_spray_mask.png` and `expected_box_segmenter_spray_cropped_mask.png` are generated with Refiners.

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 126 KiB