diff --git a/src/refiners/foundationals/segment_anything/__init__.py b/src/refiners/foundationals/segment_anything/__init__.py
index 5f1d54d..8ce3045 100644
--- a/src/refiners/foundationals/segment_anything/__init__.py
+++ b/src/refiners/foundationals/segment_anything/__init__.py
@@ -1,3 +1,4 @@
+from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter
 from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH
 
-__all__ = ["SegmentAnything", "SegmentAnythingH"]
+__all__ = ["SegmentAnything", "SegmentAnythingH", "HQSAMAdapter"]
diff --git a/src/refiners/foundationals/segment_anything/hq_sam.py b/src/refiners/foundationals/segment_anything/hq_sam.py
index f87e1e0..87129f5 100644
--- a/src/refiners/foundationals/segment_anything/hq_sam.py
+++ b/src/refiners/foundationals/segment_anything/hq_sam.py
@@ -291,6 +291,19 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
     """Adapter for SAM introducing HQ features.
 
     See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
+
+    Example:
+        ```py
+        from refiners.fluxion.utils import load_from_safetensors
+
+        # Tip: run scripts/prepare_test_weights.py to download the weights
+        tensor_path = "./tests/weights/refiners-sam-hq-vit-h.safetensors"
+        weights = load_from_safetensors(tensor_path)
+
+        # sam_h is a SegmentAnythingH instance, built as in its docstring example
+        hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
+        hq_sam_adapter.inject()  # then use SAM as usual
+        ```
     """
 
     _adapter_modules: dict[str, fl.Module] = {}
diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py
index 48126f6..1f94e70 100644
--- a/src/refiners/foundationals/segment_anything/model.py
+++ b/src/refiners/foundationals/segment_anything/model.py
@@ -25,6 +25,8 @@ class SegmentAnything(fl.Chain):
 
     See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
 
+    See [`SegmentAnythingH`][refiners.foundationals.segment_anything.model.SegmentAnythingH] for a usage example.
+
     Attributes:
         mask_threshold (float): 0.0
     """
@@ -262,6 +264,35 @@ class SegmentAnythingH(SegmentAnything):
             multimask_output: Whether to use multimask output.
             device: The PyTorch device to use.
             dtype: The PyTorch data type to use.
+
+        Example:
+            ```py
+            import torch
+            from PIL import Image
+
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+
+            # multimask_output=True is recommended for ambiguous prompts such as a single point.
+            # Here a box prompt is passed, so multimask_output=False is used to return a single mask.
+            sam_h = SegmentAnythingH(multimask_output=False, device=device)
+
+            # Tip: run scripts/prepare_test_weights.py to download the weights
+            tensors_path = "./tests/weights/segment-anything-h.safetensors"
+            sam_h.load_from_safetensors(tensors_path=tensors_path)
+
+            image = Image.open("image.png")
+
+            # (x1, y1) and (x2, y2) are the top-left and bottom-right corners of the box prompt
+            masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])
+
+            assert masks.shape == (1, 1, image.height, image.width)
+            assert masks.dtype == torch.bool
+
+            # convert the mask to a [0, 255] uint8 ndarray of shape (H, W)
+            mask = masks[0, 0].cpu().numpy().astype("uint8") * 255
+
+            Image.fromarray(mask).save("mask_image.png")
+            ```
         """
         image_encoder = image_encoder or SAMViTH()
         point_encoder = point_encoder or PointEncoder()
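
The two docstring examples reference each other: the `HQSAMAdapter` snippet assumes a `sam_h` built as in the `SegmentAnythingH` example. Below is a minimal end-to-end sketch combining them, using only the APIs introduced in this diff (`SegmentAnythingH`, `load_from_safetensors`, `HQSAMAdapter.inject`, `predict`); the image path and box coordinates are hypothetical placeholders, and the weights paths assume scripts/prepare_test_weights.py has been run.

```py
import torch
from PIL import Image

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.segment_anything import HQSAMAdapter, SegmentAnythingH

device = "cuda" if torch.cuda.is_available() else "cpu"

# build SAM-H and load its weights
sam_h = SegmentAnythingH(multimask_output=False, device=device)
sam_h.load_from_safetensors(tensors_path="./tests/weights/segment-anything-h.safetensors")

# inject the HQ adapter, then use SAM as usual
weights = load_from_safetensors("./tests/weights/refiners-sam-hq-vit-h.safetensors")
HQSAMAdapter(sam_h, weights=weights).inject()

image = Image.open("image.png")  # placeholder image
masks, *_ = sam_h.predict(image, box_points=[[(10, 10), (200, 150)]])  # placeholder box corners

# save the single boolean mask as a [0, 255] grayscale image
mask = masks[0, 0].cpu().numpy().astype("uint8") * 255
Image.fromarray(mask).save("mask_image.png")
```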