SAM/HQSAMAdapter: docstring examples

This commit is contained in:
Pierre Colle 2024-04-03 10:19:15 +00:00 committed by Colle
parent e033306f60
commit d05ebb8dd3
3 changed files with 42 additions and 1 deletions

View file

@@ -1,3 +1,4 @@
from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter
from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH
__all__ = ["SegmentAnything", "SegmentAnythingH"]
__all__ = ["SegmentAnything", "SegmentAnythingH", "HQSAMAdapter"]

View file

@@ -291,6 +291,18 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
"""Adapter for SAM introducing HQ features.
See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
Example:
```py
from refiners.fluxion.utils import load_from_safetensors
# Tips: run scripts/prepare_test_weights.py to download the weights
tensor_path = "./tests/weights/refiners-sam-hq-vit-h.safetensors"
weights = load_from_safetensors(tensor_path)
hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
hq_sam_adapter.inject() # then use SAM as usual
```
"""
_adapter_modules: dict[str, fl.Module] = {}

View file

@@ -25,6 +25,8 @@ class SegmentAnything(fl.Chain):
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
E.g. see [`SegmentAnythingH`][refiners.foundationals.segment_anything.model.SegmentAnythingH] for usage.
Attributes:
mask_threshold (float): 0.0
"""
@@ -262,6 +264,32 @@ class SegmentAnythingH(SegmentAnything):
multimask_output: Whether to use multimask output.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
Example:
```py
device="cuda" if torch.cuda.is_available() else "cpu"
# multimask_output=True is recommended for ambiguous prompts such as a single point.
# Below, a box prompt is passed, so just use multimask_output=False which will return a single mask
sam_h = SegmentAnythingH(multimask_output=False, device=device)
# Tips: run scripts/prepare_test_weights.py to download the weights
tensors_path = "./tests/weights/segment-anything-h.safetensors"
sam_h.load_from_safetensors(tensors_path=tensors_path)
from PIL import Image
image = Image.open("image.png")
masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])
assert masks.shape == (1, 1, image.height, image.width)
assert masks.dtype == torch.bool
# convert it to [0,255] uint8 ndarray of shape (H, W)
mask = masks[0, 0].cpu().numpy().astype("uint8") * 255
Image.fromarray(mask).save("mask_image.png")
```
"""
image_encoder = image_encoder or SAMViTH()
point_encoder = point_encoder or PointEncoder()