SAM/HQSAMAdapter: docstring examples
parent e033306f60
commit d05ebb8dd3
@@ -1,3 +1,4 @@
+from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter
 from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH
 
-__all__ = ["SegmentAnything", "SegmentAnythingH"]
+__all__ = ["SegmentAnything", "SegmentAnythingH", "HQSAMAdapter"]
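With the `__all__` update above, the adapter is re-exported next to the models. A minimal sketch of the resulting import surface, assuming this hunk is the package `__init__` (which the `__all__` assignment suggests, but the file name is not shown here):

```py
# Both the models and the new adapter now resolve from the package namespace.
from refiners.foundationals.segment_anything import (
    HQSAMAdapter,
    SegmentAnything,
    SegmentAnythingH,
)
```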
@@ -291,6 +291,18 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
     """Adapter for SAM introducing HQ features.
 
     See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
+
+    Example:
+        ```py
+        from refiners.fluxion.utils import load_from_safetensors
+
+        # Tips: run scripts/prepare_test_weights.py to download the weights
+        tensor_path = "./tests/weights/refiners-sam-hq-vit-h.safetensors"
+        weights = load_from_safetensors(tensor_path)
+
+        hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
+        hq_sam_adapter.inject()  # then use SAM as usual
+        ```
     """
 
     _adapter_modules: dict[str, fl.Module] = {}
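The docstring example above references `sam_h` without constructing it; it is the model built in the `SegmentAnythingH` docstring further down. A self-contained sketch stitching the two examples together — the weight paths are the test fixtures named in the docstring comments, and every call is taken from the two examples:

```py
import torch

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.segment_anything import HQSAMAdapter, SegmentAnythingH

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the base SAM ViT-H model and load its weights
# (run scripts/prepare_test_weights.py first to download them).
sam_h = SegmentAnythingH(multimask_output=False, device=device)
sam_h.load_from_safetensors(tensors_path="./tests/weights/segment-anything-h.safetensors")

# Load the HQ adapter weights and patch the model in place.
weights = load_from_safetensors("./tests/weights/refiners-sam-hq-vit-h.safetensors")
hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
hq_sam_adapter.inject()

# From here on, sam_h.predict(...) runs with HQ features enabled.
```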
@@ -25,6 +25,8 @@ class SegmentAnything(fl.Chain):
 
     See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
 
+    E.g. see [`SegmentAnythingH`][refiners.foundationals.segment_anything.model.SegmentAnythingH] for usage.
+
     Attributes:
         mask_threshold (float): 0.0
     """
@@ -262,6 +264,32 @@ class SegmentAnythingH(SegmentAnything):
            multimask_output: Whether to use multimask output.
            device: The PyTorch device to use.
            dtype: The PyTorch data type to use.
+
+        Example:
+            ```py
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+
+            # multimask_output=True is recommended for ambiguous prompts such as a single point.
+            # Below, a box prompt is passed, so just use multimask_output=False which will return a single mask.
+            sam_h = SegmentAnythingH(multimask_output=False, device=device)
+
+            # Tips: run scripts/prepare_test_weights.py to download the weights
+            tensors_path = "./tests/weights/segment-anything-h.safetensors"
+            sam_h.load_from_safetensors(tensors_path=tensors_path)
+
+            from PIL import Image
+            image = Image.open("image.png")
+
+            masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])
+
+            assert masks.shape == (1, 1, image.height, image.width)
+            assert masks.dtype == torch.bool
+
+            # convert it to [0,255] uint8 ndarray of shape (H, W)
+            mask = masks[0, 0].cpu().numpy().astype("uint8") * 255
+
+            Image.fromarray(mask).save("mask_image.png")
+            ```
        """
        image_encoder = image_encoder or SAMViTH()
        point_encoder = point_encoder or PointEncoder()
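One detail this docstring leaves open is the box prompt: `(x1, y1)` and `(x2, y2)` are undefined names standing for the corners of the box, in pixel coordinates. A short sketch continuing from the `sam_h` built in the previous example — the concrete coordinates below are illustrative placeholders, not from the source:

```py
from PIL import Image

image = Image.open("image.png")

# Placeholder box prompt: top-left (x1, y1) and bottom-right (x2, y2) corners
# in pixel coordinates; adjust to frame the object you want segmented.
x1, y1, x2, y2 = 50, 50, 300, 300

# With multimask_output=False, a single box prompt yields a single boolean mask.
masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])

# Convert the (1, 1, H, W) boolean tensor to a [0, 255] uint8 image and save it.
mask = masks[0, 0].cpu().numpy().astype("uint8") * 255
Image.fromarray(mask).save("mask_image.png")
```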