Segment Anything

HQSAMAdapter

HQSAMAdapter(
    target: SegmentAnything,
    hq_mask_only: bool = False,
    weights: dict[str, Tensor] | None = None,
)

Bases: Chain, Adapter[SegmentAnything]

Adapter for SAM introducing HQ features.

See [arXiv:2306.01567] Segment Anything in High Quality for details.

Example
from refiners.fluxion.utils import load_from_safetensors

# Tip: run scripts/prepare_test_weights.py to download the weights
tensor_path = "./tests/weights/refiners-sam-hq-vit-h.safetensors"
weights = load_from_safetensors(tensor_path)

hq_sam_adapter = HQSAMAdapter(sam_h, weights=weights)
hq_sam_adapter.inject()  # then use SAM as usual
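
Once injected, the adapter is transparent to callers: sam_h.predict works as before and returns the HQ-corrected masks (or only the HQ mask when hq_mask_only=True). A minimal sketch, assuming sam_h is a single-mask SegmentAnythingH with its weights loaded, image is a PIL image, and the box corners are placeholders:

masks, iou_predictions, low_res_masks = sam_h.predict(
    image, box_points=[[(x1, y1), (x2, y2)]]
)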

Parameters:

    target (SegmentAnything): The SegmentAnything model to adapt. Required.
    hq_mask_only (bool): Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask). Default: False.
    weights (dict[str, Tensor] | None): The weights of the HQSAMAdapter. Default: None.
Source code in src/refiners/foundationals/segment_anything/hq_sam.py
def __init__(
    self,
    target: SegmentAnything,
    hq_mask_only: bool = False,
    weights: dict[str, torch.Tensor] | None = None,
) -> None:
    """Initialize the adapter.

    Args:
        target: The SegmentAnything model to adapt.
        hq_mask_only: Whether to output only the high-quality mask or use it for mask correction (by summing it with the base SAM mask).
        weights: The weights of the HQSAMAdapter.
    """
    self.vit_embedding_dim = target.image_encoder.embedding_dim
    self.target_num_mask_tokens = target.mask_decoder.num_multimask_outputs + 2

    with self.setup_adapter(target):
        super().__init__(target)

    if target.mask_decoder.multimask_output:
        raise NotImplementedError("Multi-mask mode is not supported in HQSAMAdapter.")

    mask_prediction = target.mask_decoder.ensure_find(MaskPrediction)

    self._mask_prediction_adapter = [
        MaskPredictionAdapter(
            mask_prediction, self.vit_embedding_dim, self.target_num_mask_tokens, target.device, target.dtype
        )
    ]
    self._register_adapter_module("Chain.HQSAMMaskPrediction", self.mask_prediction_adapter.hq_sam_mask_prediction)

    self._image_encoder_adapter = [SAMViTAdapter(target.image_encoder)]
    self._predictions_post_proc = [PredictionsPostProc(hq_mask_only)]

    mask_decoder_tokens = target.mask_decoder.ensure_find(MaskDecoderTokens)
    self._mask_decoder_tokens_extender = [MaskDecoderTokensExtender(mask_decoder_tokens)]
    self._register_adapter_module("MaskDecoderTokensExtender.hq_token", self.mask_decoder_tokens_extender.hq_token)

    if weights is not None:
        self.load_weights(weights)

    self.to(device=target.device, dtype=target.dtype)

SegmentAnything

SegmentAnything(
    image_encoder: SAMViT,
    point_encoder: PointEncoder,
    mask_encoder: MaskEncoder,
    mask_decoder: MaskDecoder,
    device: device | str = "cpu",
    dtype: dtype = float32,
)

Bases: Chain

SegmentAnything model.

See [arXiv:2304.02643] Segment Anything

E.g. see SegmentAnythingH for usage.

Attributes:

    mask_threshold (float): Threshold used to binarize the predicted masks (see predict). Value: 0.0.

Parameters:

    image_encoder (SAMViT): The image encoder to use. Required.
    point_encoder (PointEncoder): The point encoder to use. Required.
    mask_encoder (MaskEncoder): The mask encoder to use. Required.
    mask_decoder (MaskDecoder): The mask decoder to use. Required.
Source code in src/refiners/foundationals/segment_anything/model.py
def __init__(
    self,
    image_encoder: SAMViT,
    point_encoder: PointEncoder,
    mask_encoder: MaskEncoder,
    mask_decoder: MaskDecoder,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
) -> None:
    """Initialize SegmentAnything model.

    Args:
        image_encoder: The image encoder to use.
        point_encoder: The point encoder to use.
        mask_encoder: The mask encoder to use.
        mask_decoder: The mask decoder to use.
    """
    super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)

    self.to(device=device, dtype=dtype)

image_encoder property

image_encoder: SAMViT

The image encoder.

image_encoder_resolution property

image_encoder_resolution: int

The resolution of the image encoder.

mask_decoder property

mask_decoder: MaskDecoder

The mask decoder.

mask_encoder property

mask_encoder: MaskEncoder

The mask encoder.

point_encoder property

point_encoder: PointEncoder

The point encoder.

compute_image_embedding

compute_image_embedding(image: Image) -> ImageEmbedding

Compute the embedding of an image.

Parameters:

    image (Image): The image to compute the embedding of. Required.

Returns:

    ImageEmbedding: The computed image embedding.

Source code in src/refiners/foundationals/segment_anything/model.py
@no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
    """Compute the embedding of an image.

    Args:
        image: The image to compute the embedding of.

    Returns:
        The computed image embedding.
    """
    original_size = (image.height, image.width)
    return ImageEmbedding(
        features=self.image_encoder(self.preprocess_image(image)),
        original_image_size=original_size,
    )
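
Precomputing the embedding is useful when the same image is prompted several times: the costly ViT forward pass runs once and the cached features are reused. A minimal sketch, assuming sam_h is a loaded SegmentAnythingH, image is a PIL image, and the point coordinates are placeholders:

image_embedding = sam_h.compute_image_embedding(image)

masks_a, *_ = sam_h.predict(image_embedding, foreground_points=[(125.0, 200.0)])
masks_b, *_ = sam_h.predict(image_embedding, foreground_points=[(300.0, 50.0)])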

normalize

normalize(
    coordinates: Tensor, original_size: tuple[int, int]
) -> Tensor

See normalize_coordinates.

Parameters:

    coordinates (Tensor): A tensor of coordinates. Required.
    original_size (tuple[int, int]): The original size (h, w) of the image. Required.

Returns:

    Tensor: The [0,1] normalized coordinates tensor.

Source code in src/refiners/foundationals/segment_anything/model.py
def normalize(self, coordinates: Tensor, original_size: tuple[int, int]) -> Tensor:
    """
    See [`normalize_coordinates`][refiners.foundationals.segment_anything.utils.normalize_coordinates]
    Args:
        coordinates: a tensor of coordinates.
        original_size: (h, w) the original size of the image.
    Returns:
        The [0,1] normalized coordinates tensor.
    """
    return normalize_coordinates(coordinates, original_size, self.image_encoder_resolution)

postprocess_masks

postprocess_masks(
    low_res_masks: Tensor, original_size: tuple[int, int]
) -> Tensor

See postprocess_masks.

Parameters:

    low_res_masks (Tensor): A mask tensor of size (N, 1, 256, 256). Required.
    original_size (tuple[int, int]): The original size (h, w) of the image. Required.

Returns:

    Tensor: The mask of shape (N, 1, H, W).

Source code in src/refiners/foundationals/segment_anything/model.py
def postprocess_masks(self, low_res_masks: Tensor, original_size: tuple[int, int]) -> Tensor:
    """
    See [`postprocess_masks`][refiners.foundationals.segment_anything.utils.postprocess_masks]
    Args:
        low_res_masks: a mask tensor of size (N, 1, 256, 256)
        original_size: (h, w) the original size of the image.
    Returns:
        The mask of shape (N, 1, H, W)
    """
    return postprocess_masks(low_res_masks, original_size, self.image_encoder_resolution)

predict

predict(
    input: Image | ImageEmbedding,
    foreground_points: (
        Sequence[tuple[float, float]] | None
    ) = None,
    background_points: (
        Sequence[tuple[float, float]] | None
    ) = None,
    box_points: (
        Sequence[Sequence[tuple[float, float]]] | None
    ) = None,
    low_res_mask: (
        Float[Tensor, "1 1 256 256"] | None
    ) = None,
    binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]

Predict the masks of the input image.

Parameters:

    input (Image | ImageEmbedding): The input image or its embedding. Required.
    foreground_points (Sequence[tuple[float, float]] | None): The points of the foreground. Default: None.
    background_points (Sequence[tuple[float, float]] | None): The points of the background. Default: None.
    box_points (Sequence[Sequence[tuple[float, float]]] | None): The points of the box. Default: None.
    low_res_mask (Float[Tensor, "1 1 256 256"] | None): The low resolution mask. Default: None.
    binarize (bool): Whether to binarize the masks. Default: True.

Returns:

    Tensor: The predicted masks.
    Tensor: The IOU prediction.
    Tensor: The low resolution masks.

Source code in src/refiners/foundationals/segment_anything/model.py
@no_grad()
def predict(
    self,
    input: Image.Image | ImageEmbedding,
    foreground_points: Sequence[tuple[float, float]] | None = None,
    background_points: Sequence[tuple[float, float]] | None = None,
    box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
    low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
    binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]:
    """Predict the masks of the input image.

    Args:
        input: The input image or its embedding.
        foreground_points: The points of the foreground.
        background_points: The points of the background.
        box_points: The points of the box.
        low_res_mask: The low resolution mask.
        binarize: Whether to binarize the masks.

    Returns:
        The predicted masks.
        The IOU prediction.
        The low resolution masks.
    """
    if isinstance(input, ImageEmbedding):
        original_size = input.original_image_size
        image_embedding = input.features
    else:
        original_size = (input.height, input.width)
        image_embedding = self.image_encoder(self.preprocess_image(input))

    coordinates, type_mask = self.point_encoder.points_to_tensor(
        foreground_points=foreground_points,
        background_points=background_points,
        box_points=box_points,
    )
    self.point_encoder.set_type_mask(type_mask=type_mask)

    if low_res_mask is not None:
        mask_embedding = self.mask_encoder(low_res_mask)
    else:
        mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
            image_embedding_size=self.image_encoder.image_embedding_size
        )

    point_embedding = self.point_encoder(self.normalize(coordinates, original_size=original_size))
    dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
        image_embedding_size=self.image_encoder.image_embedding_size
    )

    self.mask_decoder.set_image_embedding(image_embedding=image_embedding)
    self.mask_decoder.set_mask_embedding(mask_embedding=mask_embedding)
    self.mask_decoder.set_point_embedding(point_embedding=point_embedding)
    self.mask_decoder.set_dense_positional_embedding(dense_positional_embedding=dense_positional_embedding)

    low_res_masks, iou_predictions = self.mask_decoder()

    high_res_masks = self.postprocess_masks(low_res_masks, original_size)

    if binarize:
        high_res_masks = high_res_masks > self.mask_threshold

    return high_res_masks, iou_predictions, low_res_masks
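
Since the low-resolution masks are returned alongside the high-resolution ones, they can be fed back through the low_res_mask argument on a follow-up call, e.g. to refine a point prompt. A minimal sketch, assuming sam_h is a loaded single-mask SegmentAnythingH (so low_res_masks has shape (1, 1, 256, 256)), image is a PIL image, and the coordinates are placeholders:

masks, iou_predictions, low_res_masks = sam_h.predict(
    image, foreground_points=[(250.0, 125.0)]
)
refined_masks, *_ = sam_h.predict(
    image, foreground_points=[(250.0, 125.0)], low_res_mask=low_res_masks
)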

preprocess_image

preprocess_image(image: Image) -> Tensor

See preprocess_image.

Parameters:

    image (Image): The image to preprocess. Required.

Returns:

    Tensor: The preprocessed tensor.

Source code in src/refiners/foundationals/segment_anything/model.py
def preprocess_image(self, image: Image.Image) -> Tensor:
    """
    See [`preprocess_image`][refiners.foundationals.segment_anything.utils.preprocess_image]
    Args:
        image: The image to preprocess.
    Returns:
        The preprocessed tensor.
    """
    return preprocess_image(image, self.image_encoder_resolution, self.device, self.dtype)

SegmentAnythingH

SegmentAnythingH(
    image_encoder: SAMViTH | None = None,
    point_encoder: PointEncoder | None = None,
    mask_encoder: MaskEncoder | None = None,
    mask_decoder: MaskDecoder | None = None,
    multimask_output: bool | None = None,
    device: device | str = "cpu",
    dtype: dtype = float32,
)

Bases: SegmentAnything

SegmentAnything huge model.

Parameters:

    image_encoder (SAMViTH | None): The image encoder to use. Default: None.
    point_encoder (PointEncoder | None): The point encoder to use. Default: None.
    mask_encoder (MaskEncoder | None): The mask encoder to use. Default: None.
    mask_decoder (MaskDecoder | None): The mask decoder to use. Default: None.
    multimask_output (bool | None): Whether to use multimask output. Default: None.
    device (device | str): The PyTorch device to use. Default: "cpu".
    dtype (dtype): The PyTorch data type to use. Default: float32.
Example
device="cuda" if torch.cuda.is_available() else "cpu"

# multimask_output=True is recommended for ambiguous prompts such as a single point.
# Below, a box prompt is passed, so just use multimask_output=False which will return a single mask
sam_h = SegmentAnythingH(multimask_output=False, device=device)

# Tips: run scripts/prepare_test_weights.py to download the weights
tensors_path = "./tests/weights/segment-anything-h.safetensors"
sam_h.load_from_safetensors(tensors_path=tensors_path)

from PIL import Image
image = Image.open("image.png")

masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])

assert masks.shape == (1, 1, image.height, image.width)
assert masks.dtype == torch.bool

# convert it to [0,255] uint8 ndarray of shape (H, W)
mask = masks[0, 0].cpu().numpy().astype("uint8") * 255

Image.fromarray(mask).save("mask_image.png")
Source code in src/refiners/foundationals/segment_anything/model.py
def __init__(
    self,
    image_encoder: SAMViTH | None = None,
    point_encoder: PointEncoder | None = None,
    mask_encoder: MaskEncoder | None = None,
    mask_decoder: MaskDecoder | None = None,
    multimask_output: bool | None = None,
    device: Device | str = "cpu",
    dtype: DType = torch.float32,
) -> None:
    """Initialize SegmentAnything huge model.

    Args:
        image_encoder: The image encoder to use.
        point_encoder: The point encoder to use.
        mask_encoder: The mask encoder to use.
        mask_decoder: The mask decoder to use.
        multimask_output: Whether to use multimask output.
        device: The PyTorch device to use.
        dtype: The PyTorch data type to use.

    Example:
        ```py
        device="cuda" if torch.cuda.is_available() else "cpu"

        # multimask_output=True is recommended for ambiguous prompts such as a single point.
        # Below, a box prompt is passed, so just use multimask_output=False which will return a single mask
        sam_h = SegmentAnythingH(multimask_output=False, device=device)

        # Tips: run scripts/prepare_test_weights.py to download the weights
        tensors_path = "./tests/weights/segment-anything-h.safetensors"
        sam_h.load_from_safetensors(tensors_path=tensors_path)

        from PIL import Image
        image = Image.open("image.png")

        masks, *_ = sam_h.predict(image, box_points=[[(x1, y1), (x2, y2)]])

        assert masks.shape == (1, 1, image.height, image.width)
        assert masks.dtype == torch.bool

        # convert it to [0,255] uint8 ndarray of shape (H, W)
        mask = masks[0, 0].cpu().numpy().astype("uint8") * 255

        Image.fromarray(mask).save("mask_image.png")
        ```
    """
    image_encoder = image_encoder or SAMViTH()
    point_encoder = point_encoder or PointEncoder()
    mask_encoder = mask_encoder or MaskEncoder()

    if mask_decoder:
        assert (
            multimask_output is None or mask_decoder.multimask_output == multimask_output
        ), f"mask_decoder.multimask_output {mask_decoder.multimask_output} should match multimask_output ({multimask_output})"
    else:
        mask_decoder = MaskDecoder(multimask_output) if multimask_output is not None else MaskDecoder()

    super().__init__(image_encoder, point_encoder, mask_encoder, mask_decoder)

    self.to(device=device, dtype=dtype)

image_encoder property

image_encoder: SAMViTH

The image encoder.

compute_scaled_size

compute_scaled_size(
    size: tuple[int, int], image_encoder_resolution: int
) -> tuple[int, int]

Compute the scaled size as expected by the image encoder. The computed size keeps the aspect ratio of the input image and scales it to fit inside the (image_encoder_resolution, image_encoder_resolution) square of the image encoder.

Parameters:

    size (tuple[int, int]): The size (h, w) of the input image. Required.
    image_encoder_resolution (int): Image encoder resolution. Required.

Returns:

    tuple[int, int]: The target (height, width).

Source code in src/refiners/foundationals/segment_anything/utils.py
def compute_scaled_size(size: tuple[int, int], image_encoder_resolution: int) -> tuple[int, int]:
    """Compute the scaled size as expected by the image encoder.
    This computed size keep the ratio of the input image, and scale it to fit inside the square (image_encoder_resolution, image_encoder_resolution) of image encoder.

    Args:
        size: The size (h, w) of the input image.
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The target height.
        The target width.
    """
    oldh, oldw = size
    scale = image_encoder_resolution * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    neww = int(neww + 0.5)
    newh = int(newh + 0.5)
    return (newh, neww)
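
For example, with a 1024-pixel encoder resolution, a 480x640 (h, w) image scales to 768x1024: the longer side is mapped to the encoder resolution and the shorter side keeps the aspect ratio.

from refiners.foundationals.segment_anything.utils import compute_scaled_size

scaled_size = compute_scaled_size((480, 640), image_encoder_resolution=1024)
assert scaled_size == (768, 1024)  # scale = 1024 / 640 = 1.6, so 480 -> 768 and 640 -> 1024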

image_to_scaled_tensor

image_to_scaled_tensor(
    image: Image,
    scaled_size: tuple[int, int],
    device: device | None = None,
    dtype: dtype | None = None,
) -> Tensor

Resize the image to scaled_size and convert it to a tensor.

Parameters:

    image (Image): The image. Required.
    scaled_size (tuple[int, int]): The target size (h, w). Required.
    device (device | None): Tensor device. Default: None.
    dtype (dtype | None): Tensor dtype. Default: None.

Returns:

    Tensor: A tensor of shape (1, c, h, w).

Source code in src/refiners/foundationals/segment_anything/utils.py
def image_to_scaled_tensor(
    image: Image.Image, scaled_size: tuple[int, int], device: Device | None = None, dtype: DType | None = None
) -> Tensor:
    """Resize the image to `scaled_size` and convert it to a tensor.

    Args:
        image: The image.
        scaled_size: The target size (h, w).
        device: Tensor device.
        dtype: Tensor dtype.
    Returns:
        a Tensor of shape (1, c, h, w)
    """
    h, w = scaled_size
    resized = image.resize((w, h), resample=Image.Resampling.BILINEAR)  # type: ignore

    return image_to_tensor(resized, device=device, dtype=dtype) * 255.0
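
A minimal sketch, assuming an RGB PIL image at a placeholder path and the scaled size computed above; note the returned values are in the [0, 255] range, not [0, 1]:

from PIL import Image
from refiners.foundationals.segment_anything.utils import image_to_scaled_tensor

image = Image.open("image.png")  # placeholder path
scaled_tensor = image_to_scaled_tensor(image, (768, 1024))
assert scaled_tensor.shape == (1, 3, 768, 1024)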

normalize_coordinates

normalize_coordinates(
    coordinates: Tensor,
    original_size: tuple[int, int],
    image_encoder_resolution: int,
) -> Tensor

Normalize the coordinates in the [0,1] range

Parameters:

    coordinates (Tensor): The coordinates to normalize. Required.
    original_size (tuple[int, int]): The original image size. Required.
    image_encoder_resolution (int): Image encoder resolution. Required.

Returns:

    Tensor: The normalized coordinates.

Source code in src/refiners/foundationals/segment_anything/utils.py
def normalize_coordinates(coordinates: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
    """Normalize the coordinates in the [0,1] range

    Args:
        coordinates: The coordinates to normalize.
        original_size: The original image size.
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The normalized coordinates.
    """
    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
    coordinates[:, :, 0] = (
        (coordinates[:, :, 0] * (scaled_size[1] / original_size[1])) + 0.5
    ) / image_encoder_resolution
    coordinates[:, :, 1] = (
        (coordinates[:, :, 1] * (scaled_size[0] / original_size[0])) + 0.5
    ) / image_encoder_resolution
    return coordinates
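
The coordinates tensor is given in (x, y) pixel order with shape (batch, num_points, 2), and is modified in place. A small worked example with a 480x640 (h, w) image and a 1024-pixel encoder resolution:

import torch

from refiners.foundationals.segment_anything.utils import normalize_coordinates

coordinates = torch.tensor([[[320.0, 240.0]]])  # one (x, y) point, shape (1, 1, 2)
normalized = normalize_coordinates(coordinates, original_size=(480, 640), image_encoder_resolution=1024)
# x: (320 * 1024 / 640 + 0.5) / 1024 ≈ 0.5005, y: (240 * 768 / 480 + 0.5) / 1024 ≈ 0.3755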

pad_image_tensor

pad_image_tensor(
    image_tensor: Tensor,
    scaled_size: tuple[int, int],
    image_encoder_resolution: int,
) -> Tensor

Pad an image with zeros to make it square.

Parameters:

    image_tensor (Tensor): The image tensor to pad. Required.
    scaled_size (tuple[int, int]): The scaled size (h, w). Required.
    image_encoder_resolution (int): Image encoder resolution. Required.

Returns:

    Tensor: The padded image.

Source code in src/refiners/foundationals/segment_anything/utils.py
def pad_image_tensor(image_tensor: Tensor, scaled_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
    """Pad an image with zeros to make it square.

    Args:
        image_tensor: The image tensor to pad.
        scaled_size: The scaled size (h, w).
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The padded image.
    """
    assert len(image_tensor.shape) == 4
    assert image_tensor.shape[2] <= image_encoder_resolution
    assert image_tensor.shape[3] <= image_encoder_resolution

    h, w = scaled_size
    padh = image_encoder_resolution - h
    padw = image_encoder_resolution - w
    return pad(image_tensor, (0, padw, 0, padh))
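
A minimal sketch: a 768x1024 scaled tensor is padded on the bottom to reach the 1024x1024 square expected by the encoder.

import torch

from refiners.foundationals.segment_anything.utils import pad_image_tensor

image_tensor = torch.zeros(1, 3, 768, 1024)
padded = pad_image_tensor(image_tensor, scaled_size=(768, 1024), image_encoder_resolution=1024)
assert padded.shape == (1, 3, 1024, 1024)  # 256 zero rows appended at the bottom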

postprocess_masks

postprocess_masks(
    low_res_masks: Tensor,
    original_size: tuple[int, int],
    image_encoder_resolution: int,
) -> Tensor

Postprocess the masks to fit the original image size and remove zero-padding (if any).

Parameters:

    low_res_masks (Tensor): The masks to postprocess. Required.
    original_size (tuple[int, int]): The original size (h, w). Required.
    image_encoder_resolution (int): Image encoder resolution. Required.

Returns:

    Tensor: The postprocessed masks.

Source code in src/refiners/foundationals/segment_anything/utils.py
def postprocess_masks(low_res_masks: Tensor, original_size: tuple[int, int], image_encoder_resolution: int) -> Tensor:
    """Postprocess the masks to fit the original image size and remove zero-padding (if any).

    Args:
        low_res_masks: The masks to postprocess.
        original_size: The original size (h, w).
        image_encoder_resolution: Image encoder resolution.

    Returns:
        The postprocessed masks.
    """
    scaled_size = compute_scaled_size(original_size, image_encoder_resolution)
    masks = interpolate(low_res_masks, size=Size((image_encoder_resolution, image_encoder_resolution)), mode="bilinear")
    masks = masks[..., : scaled_size[0], : scaled_size[1]]  # remove padding added at `preprocess_image` time
    masks = interpolate(masks, size=Size(original_size), mode="bilinear")
    return masks
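
A minimal sketch: the low-resolution decoder masks are upscaled to the encoder resolution, cropped to remove the padding, then resized to the original image size.

import torch

from refiners.foundationals.segment_anything.utils import postprocess_masks

low_res_masks = torch.rand(1, 1, 256, 256)
masks = postprocess_masks(low_res_masks, original_size=(480, 640), image_encoder_resolution=1024)
assert masks.shape == (1, 1, 480, 640)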

preprocess_image

preprocess_image(
    image: Image,
    image_encoder_resolution: int,
    device: device | None = None,
    dtype: dtype | None = None,
) -> Tensor

Preprocess an image without distorting its aspect ratio.

Parameters:

    image (Image): The image to preprocess before calling the image encoder. Required.
    image_encoder_resolution (int): Image encoder resolution. Required.
    device (device | None): Tensor device. Default: None.
    dtype (dtype | None): Tensor dtype. Default: None.

Returns:

    Tensor: The preprocessed image.

Source code in src/refiners/foundationals/segment_anything/utils.py
def preprocess_image(
    image: Image.Image, image_encoder_resolution: int, device: Device | None = None, dtype: DType | None = None
) -> Tensor:
    """Preprocess an image without distorting its aspect ratio.

    Args:
        image: The image to preprocess before calling the image encoder.
        image_encoder_resolution: Image encoder resolution.
        device: Tensor device (None by default).
        dtype: Tensor dtype (None by default).

    Returns:
        The preprocessed image.
    """

    scaled_size = compute_scaled_size((image.height, image.width), image_encoder_resolution)

    image_tensor = image_to_scaled_tensor(image, scaled_size, device=device, dtype=dtype)

    return pad_image_tensor(
        normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
        scaled_size,
        image_encoder_resolution,
    )
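
A minimal sketch of the full pipeline (scale, convert to tensor, normalize, pad), assuming an RGB image at a placeholder path and a 1024-pixel encoder resolution:

from PIL import Image

from refiners.foundationals.segment_anything.utils import preprocess_image

image = Image.open("image.png")  # placeholder path
image_tensor = preprocess_image(image, image_encoder_resolution=1024)
assert image_tensor.shape == (1, 3, 1024, 1024)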