mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
HQ-SAM logit equal test, following #331
This commit is contained in:
parent
2763db960e
commit
5c937b184a
|
@ -144,7 +144,7 @@ exclude_also = [
|
|||
|
||||
[tool.typos.default]
|
||||
extend-words = { adaptee = "adaptee" }
|
||||
extend-ignore-identifiers-re = ["NDArray*"]
|
||||
extend-ignore-identifiers-re = ["NDArray*", "interm"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
filterwarnings = [
|
||||
|
|
|
@ -23,7 +23,7 @@ from refiners.foundationals.segment_anything.hq_sam import (
|
|||
MaskDecoderTokensExtender,
|
||||
PredictionsPostProc,
|
||||
)
|
||||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
||||
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -248,8 +248,8 @@ def test_predictor(
|
|||
reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
|
||||
iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore
|
||||
|
||||
# NOTE: Diff on logits is relatively high, but on the same scale / even lower than base SAM logits diff (6e-3)
|
||||
# See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426
|
||||
# NOTE: Diff on logits is relatively high,
|
||||
# see test_predictor_equal for a stricter version
|
||||
assert torch.allclose(
|
||||
reference_low_res_mask_hq,
|
||||
refiners_low_res_mask_hq,
|
||||
|
@ -265,6 +265,58 @@ def test_predictor(
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor_equal(
    sam_h: SegmentAnythingH,
    hq_adapter_weights: Path,
    hq_mask_only: bool,
    reference_sam_h_predictor: FacebookSAMPredictorHQ,
    tennis: Image.Image,
    one_prompt: SAMPrompt,
) -> None:
    """Check that HQ-SAM masks from Refiners are strictly equal to the reference ones.

    The reference predictor encodes a pre-resized image; its image features and
    first intermediate feature are then handed to the Refiners model, so only the
    stages after the image encoder are compared. Low-res masks are compared with
    `torch.equal`, high-res masks through a zero sum of absolute differences.
    """
    adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

    # The flag set on the adapter must be visible on the injected post-processing module.
    adapter.hq_mask_only = hq_mask_only
    post_proc = sam_h.ensure_find(PredictionsPostProc)
    assert post_proc.hq_mask_only == hq_mask_only

    # Same recipe as test_predictor_resized_single_output in test_sam.py:
    # strict equality requires resizing the image up front and feeding
    # an image embedding (rather than an image) to the model.
    size = (1024, 1024)
    reference_input = np.array(tennis.resize(size))

    # --- Reference implementation ---
    reference_sam_h_predictor.set_image(reference_input)

    prompt_kwargs = vars(one_prompt)
    box = np.array(prompt_kwargs["box_points"]).flatten()
    masks_np, _, low_res_masks_np = reference_sam_h_predictor.predict(
        box=box,
        multimask_output=False,
        hq_token_only=hq_mask_only,
    )

    reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore
    reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32)  # type: ignore

    # --- Refiners implementation ---
    # The Refiners ViT is bypassed: the reference features and interm_features
    # are injected directly, which is what makes torch.equal achievable.
    reference_image_embedding = ImageEmbedding(
        features=reference_sam_h_predictor.features,
        original_image_size=size,
    )
    early_vit_embedding = reference_sam_h_predictor.interm_features[0]
    adapter.set_context("hq_sam", {"early_vit_embedding": early_vit_embedding})

    high_res_masks, _, low_res_masks = sam_h.predict(reference_image_embedding, **prompt_kwargs)
    refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
    refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()

    assert torch.equal(reference_low_res_mask_hq, refiners_low_res_mask_hq)

    high_res_abs_diff = torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq)
    assert high_res_abs_diff.flatten().sum() == 0
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
|
||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||
|
|
|
@ -57,6 +57,8 @@ class FacebookSAMPredictor:
|
|||
|
||||
class FacebookSAMPredictorHQ:
|
||||
model: FacebookSAM
|
||||
features: Tensor
|
||||
interm_features: Tensor
|
||||
|
||||
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
|
||||
|
||||
|
|
Loading…
Reference in a new issue