mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add box segmenter solution
This commit is contained in:
parent
c8ff6d95f0
commit
7ca1774b5f
|
@ -21,6 +21,8 @@ ______________________________________________________________________
|
||||||
|
|
||||||
## Latest News 🔥
|
## Latest News 🔥
|
||||||
|
|
||||||
|
- Added the Box Segmenter all-in-one solution ([model](https://huggingface.co/finegrain/finegrain-box-segmenter), [HF Space](https://huggingface.co/spaces/finegrain/finegrain-object-cutter))
|
||||||
|
- Added [MVANet](https://arxiv.org/abs/2404.07445) for high resolution segmentation
|
||||||
- Added [IC-Light](https://github.com/lllyasviel/IC-Light) to manipulate the illumination of images
|
- Added [IC-Light](https://github.com/lllyasviel/IC-Light) to manipulate the illumination of images
|
||||||
- Added Multi Upscaler for high-resolution image generation, inspired from [Clarity Upscaler](https://github.com/philz1337x/clarity-upscaler) ([HF Space](https://huggingface.co/spaces/finegrain/enhancer))
|
- Added Multi Upscaler for high-resolution image generation, inspired from [Clarity Upscaler](https://github.com/philz1337x/clarity-upscaler) ([HF Space](https://huggingface.co/spaces/finegrain/enhancer))
|
||||||
- Added [HQ-SAM](https://arxiv.org/abs/2306.01567) for high quality mask prediction with Segment Anything
|
- Added [HQ-SAM](https://arxiv.org/abs/2306.01567) for high quality mask prediction with Segment Anything
|
||||||
|
|
|
@ -67,6 +67,9 @@ doc = [
|
||||||
"mkdocstrings[python]>=0.24.0",
|
"mkdocstrings[python]>=0.24.0",
|
||||||
"mkdocs-literate-nav>=0.6.1",
|
"mkdocs-literate-nav>=0.6.1",
|
||||||
]
|
]
|
||||||
|
solutions = [
|
||||||
|
"huggingface-hub>=0.24.6",
|
||||||
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
|
|
|
@ -100,9 +100,10 @@ gitpython==3.1.43
|
||||||
# via wandb
|
# via wandb
|
||||||
griffe==0.48.0
|
griffe==0.48.0
|
||||||
# via mkdocstrings-python
|
# via mkdocstrings-python
|
||||||
huggingface-hub==0.24.5
|
huggingface-hub==0.24.6
|
||||||
# via datasets
|
# via datasets
|
||||||
# via diffusers
|
# via diffusers
|
||||||
|
# via refiners
|
||||||
# via timm
|
# via timm
|
||||||
# via tokenizers
|
# via tokenizers
|
||||||
# via transformers
|
# via transformers
|
||||||
|
|
|
@ -466,6 +466,15 @@ def download_mvanet():
|
||||||
check_hash(dest_filename, "b915d492")
|
check_hash(dest_filename, "b915d492")
|
||||||
|
|
||||||
|
|
||||||
|
def download_box_segmenter():
|
||||||
|
download_file(
|
||||||
|
"https://huggingface.co/finegrain/finegrain-box-segmenter/resolve/v0.1/model.safetensors",
|
||||||
|
dest_folder=test_weights_dir,
|
||||||
|
filename="finegrain-box-segmenter-v0-1.safetensors",
|
||||||
|
expected_hash="e0450e8c",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def printg(msg: str):
|
def printg(msg: str):
|
||||||
"""print in green color"""
|
"""print in green color"""
|
||||||
print("\033[92m" + msg + "\033[0m")
|
print("\033[92m" + msg + "\033[0m")
|
||||||
|
@ -861,6 +870,7 @@ def download_all():
|
||||||
download_sdxl_lightning_lora()
|
download_sdxl_lightning_lora()
|
||||||
download_ic_light()
|
download_ic_light()
|
||||||
download_mvanet()
|
download_mvanet()
|
||||||
|
download_box_segmenter()
|
||||||
|
|
||||||
|
|
||||||
def convert_all():
|
def convert_all():
|
||||||
|
|
3
src/refiners/solutions/__init__.py
Normal file
3
src/refiners/solutions/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .box_segmenter import BoxSegmenter
|
||||||
|
|
||||||
|
__all__ = ["BoxSegmenter"]
|
79
src/refiners/solutions/box_segmenter.py
Normal file
79
src/refiners/solutions/box_segmenter.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image
|
||||||
|
from refiners.foundationals.swin.mvanet import MVANet
|
||||||
|
|
||||||
|
BoundingBox = tuple[int, int, int, int]
|
||||||
|
|
||||||
|
|
||||||
|
class BoxSegmenter:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
margin: float = 0.05,
|
||||||
|
weights: Path | str | dict[str, torch.Tensor] | None = None,
|
||||||
|
device: torch.device | str = "cpu",
|
||||||
|
):
|
||||||
|
assert margin >= 0
|
||||||
|
self.margin = margin
|
||||||
|
|
||||||
|
self.device = torch.device(device)
|
||||||
|
self.model = MVANet(device=self.device).eval()
|
||||||
|
|
||||||
|
if weights is None:
|
||||||
|
from huggingface_hub.file_download import hf_hub_download # type: ignore[reportUnknownVariableType]
|
||||||
|
|
||||||
|
weights = hf_hub_download(
|
||||||
|
repo_id="finegrain/finegrain-box-segmenter",
|
||||||
|
filename="model.safetensors",
|
||||||
|
revision="v0.1",
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(weights, dict):
|
||||||
|
self.model.load_state_dict(weights)
|
||||||
|
else:
|
||||||
|
self.model.load_from_safetensors(weights)
|
||||||
|
|
||||||
|
def __call__(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image:
|
||||||
|
return self.run(img, box_prompt)
|
||||||
|
|
||||||
|
def add_margin(self, box: BoundingBox) -> BoundingBox:
|
||||||
|
x0, y0, x1, y1 = box
|
||||||
|
mx = int((x1 - x0) * self.margin)
|
||||||
|
my = int((y1 - y0) * self.margin)
|
||||||
|
return (x0 - mx, y0 - my, x1 + mx, y1 + my)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def crop_pad(img: Image.Image, box: BoundingBox) -> Image.Image:
|
||||||
|
img = img.convert("RGB")
|
||||||
|
|
||||||
|
x0, y0, x1, y1 = box
|
||||||
|
px0, py0, px1, py1 = (max(0, -x0), max(0, -y0), max(0, x1 - img.width), max(0, y1 - img.height))
|
||||||
|
if (px0, py0, px1, py1) == (0, 0, 0, 0):
|
||||||
|
return img.crop(box)
|
||||||
|
|
||||||
|
padded = Image.new("RGB", (img.width + px0 + px1, img.height + py0 + py1))
|
||||||
|
padded.paste(img, (px0, py0))
|
||||||
|
return padded.crop((x0 + px0, y0 + py0, x1 + px0, y1 + py0))
|
||||||
|
|
||||||
|
def predict(self, img: Image.Image) -> Image.Image:
|
||||||
|
in_t = image_to_tensor(img.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
|
||||||
|
in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
|
||||||
|
with no_grad():
|
||||||
|
prediction: torch.Tensor = self.model(in_t.to(self.device)).sigmoid()
|
||||||
|
return tensor_to_image(prediction).resize(img.size, Image.Resampling.BILINEAR)
|
||||||
|
|
||||||
|
def run(self, img: Image.Image, box_prompt: BoundingBox | None = None) -> Image.Image:
|
||||||
|
if box_prompt is None:
|
||||||
|
box_prompt = (0, 0, img.width, img.height)
|
||||||
|
|
||||||
|
box = self.add_margin(box_prompt)
|
||||||
|
cropped = self.crop_pad(img, box)
|
||||||
|
prediction = self.predict(cropped)
|
||||||
|
|
||||||
|
out = Image.new("L", (img.width, img.height))
|
||||||
|
out.paste(prediction, box)
|
||||||
|
return out
|
72
tests/e2e/test_solutions.py
Normal file
72
tests/e2e/test_solutions.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from tests.utils import ensure_similar_images
|
||||||
|
|
||||||
|
from refiners.solutions import BoxSegmenter
|
||||||
|
|
||||||
|
|
||||||
|
def _img_open(path: Path) -> Image.Image:
|
||||||
|
return Image.open(path) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ref_path(test_e2e_path: Path) -> Path:
|
||||||
|
return test_e2e_path / "test_solutions_ref"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def ref_shelves(ref_path: Path) -> Image.Image:
|
||||||
|
return _img_open(ref_path / "shelves.jpg").convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_box_segmenter_plant_mask(ref_path: Path) -> Image.Image:
|
||||||
|
return _img_open(ref_path / "expected_box_segmenter_plant_mask.png")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_box_segmenter_spray_mask(ref_path: Path) -> Image.Image:
|
||||||
|
return _img_open(ref_path / "expected_box_segmenter_spray_mask.png")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_box_segmenter_spray_cropped_mask(ref_path: Path) -> Image.Image:
|
||||||
|
return _img_open(ref_path / "expected_box_segmenter_spray_cropped_mask.png")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def box_segmenter_weights(test_weights_path: Path) -> Path:
|
||||||
|
weights = test_weights_path / "finegrain-box-segmenter-v0-1.safetensors"
|
||||||
|
if not weights.is_file():
|
||||||
|
warn(f"could not find weights at {test_weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
return weights
|
||||||
|
|
||||||
|
|
||||||
|
def test_box_segmenter(
|
||||||
|
box_segmenter_weights: Path,
|
||||||
|
ref_shelves: Image.Image,
|
||||||
|
expected_box_segmenter_plant_mask: Image.Image,
|
||||||
|
expected_box_segmenter_spray_mask: Image.Image,
|
||||||
|
expected_box_segmenter_spray_cropped_mask: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
segmenter = BoxSegmenter(weights=box_segmenter_weights, device=test_device)
|
||||||
|
|
||||||
|
plant_mask = segmenter(ref_shelves, box_prompt=(504, 82, 754, 368))
|
||||||
|
ensure_similar_images(plant_mask.convert("RGB"), expected_box_segmenter_plant_mask.convert("RGB"))
|
||||||
|
|
||||||
|
spray_box = (461, 542, 594, 823)
|
||||||
|
spray_mask = segmenter(ref_shelves, box_prompt=spray_box)
|
||||||
|
ensure_similar_images(spray_mask.convert("RGB"), expected_box_segmenter_spray_mask.convert("RGB"))
|
||||||
|
|
||||||
|
# Test left and bottom padding.
|
||||||
|
off_l, off_b = 11, 7
|
||||||
|
shelves_cropped = ref_shelves.crop((spray_box[0] - off_l, 0, ref_shelves.width, spray_box[3] + off_b))
|
||||||
|
spray_cropped_box = (off_l, spray_box[1], spray_box[2] - spray_box[0] + off_l, spray_box[3])
|
||||||
|
spray_cropped_mask = segmenter(shelves_cropped, box_prompt=spray_cropped_box)
|
||||||
|
ensure_similar_images(spray_cropped_mask.convert("RGB"), expected_box_segmenter_spray_cropped_mask.convert("RGB"))
|
3
tests/e2e/test_solutions_ref/README.md
Normal file
3
tests/e2e/test_solutions_ref/README.md
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
`shelves.jpg` is found here: https://www.freepik.com/free-photo/front-view-shelves-with-plants_6446859.htm
|
||||||
|
|
||||||
|
`expected_box_segmenter_plant_mask.png`, `expected_box_segmenter_spray_mask.png` and `expected_box_segmenter_spray_cropped_mask.png` are generated with Refiners.
|
Binary file not shown.
After Width: | Height: | Size: 19 KiB |
Binary file not shown.
After Width: | Height: | Size: 6.4 KiB |
Binary file not shown.
After Width: | Height: | Size: 7.1 KiB |
BIN
tests/e2e/test_solutions_ref/shelves.jpg
Normal file
BIN
tests/e2e/test_solutions_ref/shelves.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 126 KiB |
Loading…
Reference in a new issue