mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
add box segmenter solution
This commit is contained in:
parent
c8ff6d95f0
commit
7ca1774b5f
|
@ -21,6 +21,8 @@ ______________________________________________________________________
|
|||
|
||||
## Latest News 🔥
|
||||
|
||||
- Added the Box Segmenter all-in-one solution ([model](https://huggingface.co/finegrain/finegrain-box-segmenter), [HF Space](https://huggingface.co/spaces/finegrain/finegrain-object-cutter))
|
||||
- Added [MVANet](https://arxiv.org/abs/2404.07445) for high resolution segmentation
|
||||
- Added [IC-Light](https://github.com/lllyasviel/IC-Light) to manipulate the illumination of images
|
||||
- Added Multi Upscaler for high-resolution image generation, inspired from [Clarity Upscaler](https://github.com/philz1337x/clarity-upscaler) ([HF Space](https://huggingface.co/spaces/finegrain/enhancer))
|
||||
- Added [HQ-SAM](https://arxiv.org/abs/2306.01567) for high quality mask prediction with Segment Anything
|
||||
|
|
|
@ -67,6 +67,9 @@ doc = [
|
|||
"mkdocstrings[python]>=0.24.0",
|
||||
"mkdocs-literate-nav>=0.6.1",
|
||||
]
|
||||
solutions = [
|
||||
"huggingface-hub>=0.24.6",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
|
|
|
@ -100,9 +100,10 @@ gitpython==3.1.43
|
|||
# via wandb
|
||||
griffe==0.48.0
|
||||
# via mkdocstrings-python
|
||||
huggingface-hub==0.24.5
|
||||
huggingface-hub==0.24.6
|
||||
# via datasets
|
||||
# via diffusers
|
||||
# via refiners
|
||||
# via timm
|
||||
# via tokenizers
|
||||
# via transformers
|
||||
|
|
|
@ -466,6 +466,15 @@ def download_mvanet():
|
|||
check_hash(dest_filename, "b915d492")
|
||||
|
||||
|
||||
def download_box_segmenter():
|
||||
download_file(
|
||||
"https://huggingface.co/finegrain/finegrain-box-segmenter/resolve/v0.1/model.safetensors",
|
||||
dest_folder=test_weights_dir,
|
||||
filename="finegrain-box-segmenter-v0-1.safetensors",
|
||||
expected_hash="e0450e8c",
|
||||
)
|
||||
|
||||
|
||||
def printg(msg: str):
|
||||
"""print in green color"""
|
||||
print("\033[92m" + msg + "\033[0m")
|
||||
|
@ -861,6 +870,7 @@ def download_all():
|
|||
download_sdxl_lightning_lora()
|
||||
download_ic_light()
|
||||
download_mvanet()
|
||||
download_box_segmenter()
|
||||
|
||||
|
||||
def convert_all():
|
||||
|
|
3
src/refiners/solutions/__init__.py
Normal file
3
src/refiners/solutions/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .box_segmenter import BoxSegmenter
|
||||
|
||||
__all__ = ["BoxSegmenter"]
|
79
src/refiners/solutions/box_segmenter.py
Normal file
79
src/refiners/solutions/box_segmenter.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image
|
||||
from refiners.foundationals.swin.mvanet import MVANet
|
||||
|
||||
BoundingBox = tuple[int, int, int, int]
|
||||
|
||||
|
||||
class BoxSegmenter:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
margin: float = 0.05,
|
||||
weights: Path | str | dict[str, torch.Tensor] | None = None,
|
||||
device: torch.device | str = "cpu",
|
||||
):
|
||||
assert margin >= 0
|
||||
self.margin = margin
|
||||
|
||||
self.device = torch.device(device)
|
||||
self.model = MVANet(device=self.device).eval()
|
||||
|
||||
if weights is None:
|
||||
from huggingface_hub.file_download import hf_hub_download # type: ignore[reportUnknownVariableType]
|
||||
|
||||
weights = hf_hub_download(
|
||||
repo_id="finegrain/finegrain-box-segmenter",
|
||||
filename="model.safetensors",
|
||||
revision="v0.1",
|
||||
)
|
||||
|
||||
if isinstance(weights, dict):
|
||||
self.model.load_state_dict(weights)
|
||||
else:
|
||||
self.model.load_from_safetensors(weights)
|
||||
|
||||
def __call__(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image:
|
||||
return self.run(img, box_prompt)
|
||||
|
||||
def add_margin(self, box: BoundingBox) -> BoundingBox:
|
||||
x0, y0, x1, y1 = box
|
||||
mx = int((x1 - x0) * self.margin)
|
||||
my = int((y1 - y0) * self.margin)
|
||||
return (x0 - mx, y0 - my, x1 + mx, y1 + my)
|
||||
|
||||
@staticmethod
|
||||
def crop_pad(img: Image.Image, box: BoundingBox) -> Image.Image:
|
||||
img = img.convert("RGB")
|
||||
|
||||
x0, y0, x1, y1 = box
|
||||
px0, py0, px1, py1 = (max(0, -x0), max(0, -y0), max(0, x1 - img.width), max(0, y1 - img.height))
|
||||
if (px0, py0, px1, py1) == (0, 0, 0, 0):
|
||||
return img.crop(box)
|
||||
|
||||
padded = Image.new("RGB", (img.width + px0 + px1, img.height + py0 + py1))
|
||||
padded.paste(img, (px0, py0))
|
||||
return padded.crop((x0 + px0, y0 + py0, x1 + px0, y1 + py0))
|
||||
|
||||
def predict(self, img: Image.Image) -> Image.Image:
|
||||
in_t = image_to_tensor(img.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
|
||||
in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
|
||||
with no_grad():
|
||||
prediction: torch.Tensor = self.model(in_t.to(self.device)).sigmoid()
|
||||
return tensor_to_image(prediction).resize(img.size, Image.Resampling.BILINEAR)
|
||||
|
||||
def run(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image:
|
||||
if box_prompt is None:
|
||||
box_prompt = (0, 0, img.width, img.height)
|
||||
|
||||
box = self.add_margin(box_prompt)
|
||||
cropped = self.crop_pad(img, box)
|
||||
prediction = self.predict(cropped)
|
||||
|
||||
out = Image.new("L", (img.width, img.height))
|
||||
out.paste(prediction, box)
|
||||
return out
|
72
tests/e2e/test_solutions.py
Normal file
72
tests/e2e/test_solutions.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
from refiners.solutions import BoxSegmenter
|
||||
|
||||
|
||||
def _img_open(path: Path) -> Image.Image:
|
||||
return Image.open(path) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_path(test_e2e_path: Path) -> Path:
|
||||
return test_e2e_path / "test_solutions_ref"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_shelves(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "shelves.jpg").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_box_segmenter_plant_mask.png")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_box_segmenter_spray_mask.png")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def box_segmenter_weights(test_weights_path: Path) -> Path:
|
||||
weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors"
|
||||
if not weights.is_file():
|
||||
warn(f"could not find weights at {test_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return weights
|
||||
|
||||
|
||||
def test_box_segmenter(
|
||||
box_segmenter_weights: Path,
|
||||
ref_shelves: Image.Image,
|
||||
expected_box_segmenter_plant_mask: Image.Image,
|
||||
expected_box_segmenter_spray_mask: Image.Image,
|
||||
expected_box_segmenter_spray_cropped_mask: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device)
|
||||
|
||||
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
|
||||
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
|
||||
|
||||
spray_box = (461, 542, 594, 823)
|
||||
spray_mask = segmenter(ref_shelves, box_prompt=spray_box)
|
||||
ensure_similar_images(spray_mask.convert("RGB"), expected_box_segmenter_spray_mask.convert("RGB"))
|
||||
|
||||
# Test left and bottom padding.
|
||||
off_l, off_b = 11, 7
|
||||
shelves_cropped = ref_shelves.crop((spray_box[0] - off_l, 0, ref_shelves.width, spray_box[3] + off_b))
|
||||
spray_cropped_box = (off_l, spray_box[1], spray_box[2] - spray_box[0] + off_l, spray_box[3])
|
||||
spray_cropped_mask = segmenter(shelves_cropped, box_prompt=spray_cropped_box)
|
||||
ensure_similar_images(spray_cropped_mask.convert("RGB"), expected_box_segmenter_spray_cropped_mask.convert("RGB"))
|
3
tests/e2e/test_solutions_ref/README.md
Normal file
3
tests/e2e/test_solutions_ref/README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
`shelves.jpg` is found here: https://www.freepik.com/free-photo/front-view-shelves-with-plants_6446859.htm
|
||||
|
||||
`expected_box_segmenter_plant_mask.png`, `expected_box_segmenter_spray_mask.png` and `expected_box_segmenter_spray_cropped_mask.png` are generated with Refiners.
|
Binary file not shown.
After Width: | Height: | Size: 19 KiB |
Binary file not shown.
After Width: | Height: | Size: 6.4 KiB |
Binary file not shown.
After Width: | Height: | Size: 7.1 KiB |
BIN
tests/e2e/test_solutions_ref/shelves.jpg
Normal file
BIN
tests/e2e/test_solutions_ref/shelves.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 126 KiB |
Loading…
Reference in a new issue