mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
SAM/HQSAMAdapter: docstring examples
This commit is contained in:
parent
e033306f60
commit
d05ebb8dd3
|
@ -1,3 +1,4 @@
|
|||
from refiners.foundationals.segment_anything.hq_sam import HQSAMAdapter
|
||||
from refiners.foundationals.segment_anything.model import SegmentAnything, SegmentAnythingH
|
||||
|
||||
__all__ = ["SegmentAnything", "SegmentAnythingH"]
|
||||
__all__ = ["SegmentAnything", "SegmentAnythingH", "HQSAMAdapter"]
|
||||
|
|
|
@ -291,6 +291,18 @@ class HQSAMAdapter(fl.Chain, Adapter[SegmentAnything]):
|
|||
"""Adapter for SAM introducing HQ features.
|
||||
|
||||
See [[arXiv:2306.01567] Segment Anything in High Quality](https://arxiv.org/abs/2306.01567) for details.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from refiners.fluxion.utils import load_from_safetensors
|
||||
|
||||
# Tips: run scripts/prepare_test_weights.py to download the weights
|
||||
tensor_path = "./tests/weights/refiners-sam-hq-vit-h.safetensors"
|
||||
weights = load_from_safetensors(tensor_path)
|
||||
|
||||
hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
|
||||
hq_sam_adapter.inject() # then use SAM as usual
|
||||
```
|
||||
"""
|
||||
|
||||
_adapter_modules: dict[str, fl.Module] = {}
|
||||
|
|
|
@ -25,6 +25,8 @@ class SegmentAnything(fl.Chain):
|
|||
|
||||
See [[arXiv:2304.02643] Segment Anything](https://arxiv.org/abs/2304.02643)
|
||||
|
||||
E.g. see [`SegmentAnythingH`][refiners.foundationals.segment_anything.model.SegmentAnythingH] for usage.
|
||||
|
||||
Attributes:
|
||||
mask_threshold (float): 0.0
|
||||
"""
|
||||
|
@ -262,6 +264,32 @@ class SegmentAnythingH(SegmentAnything):
|
|||
multimask_output: Whether to use multimask output.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
|
||||
Example:
|
||||
```py
|
||||
device="cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# multimask_output=True is recommended for ambiguous prompts such as a single point.
|
||||
# Below, a box prompt is passed, so just use multimask_output=False which will return a single mask
|
||||
sam_h = SegmentAnythingH(multimask_output=False, device=device)
|
||||
|
||||
# Tips: run scripts/prepare_test_weights.py to download the weights
|
||||
tensors_path = "./tests/weights/segment-anything-h.safetensors"
|
||||
sam_h.load_from_safetensors(tensors_path=tensors_path)
|
||||
|
||||
from PIL import Image
|
||||
image = Image.open("image.png")
|
||||
|
||||
masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])
|
||||
|
||||
assert masks.shape == (1, 1, image.height, image.width)
|
||||
assert masks.dtype == torch.bool
|
||||
|
||||
# convert it to [0,255] uint8 ndarray of shape (H, W)
|
||||
mask = masks[0, 0].cpu().numpy().astype("uint8") * 255
|
||||
|
||||
Image.fromarray(mask).save("mask_image.png")
|
||||
```
|
||||
"""
|
||||
image_encoder = image_encoder or SAMViTH()
|
||||
point_encoder = point_encoder or PointEncoder()
|
||||
|
|
Loading…
Reference in a new issue