HQ-SAM logit equal test, following #331

This commit is contained in:
Pierre Colle 2024-03-22 21:30:46 +00:00
parent 2763db960e
commit 81ed659292
3 changed files with 58 additions and 4 deletions

View file

@ -144,7 +144,7 @@ exclude_also = [
[tool.typos.default] [tool.typos.default]
extend-words = { adaptee = "adaptee" } extend-words = { adaptee = "adaptee" }
extend-ignore-identifiers-re = ["NDArray*"] extend-ignore-identifiers-re = ["NDArray*", "interm"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
filterwarnings = [ filterwarnings = [

View file

@ -23,7 +23,7 @@ from refiners.foundationals.segment_anything.hq_sam import (
MaskDecoderTokensExtender, MaskDecoderTokensExtender,
PredictionsPostProc, PredictionsPostProc,
) )
from refiners.foundationals.segment_anything.model import SegmentAnythingH from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -248,8 +248,8 @@ def test_predictor(
reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore
# NOTE: Diff on logits is relatively high, but on the same scale / even lower than base SAM logits diff (6e-3) # NOTE: Diff on logits is relatively high,
# See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426 # see test_predictor_equal for a stricter version
assert torch.allclose( assert torch.allclose(
reference_low_res_mask_hq, reference_low_res_mask_hq,
refiners_low_res_mask_hq, refiners_low_res_mask_hq,
@ -265,6 +265,58 @@ def test_predictor(
) )
@pytest.mark.parametrize("hq_mask_only", [True, False])
def test_predictor_equal(
sam_h: SegmentAnythingH,
hq_adapter_weights: Path,
hq_mask_only: bool,
reference_sam_h_predictor: FacebookSAMPredictorHQ,
tennis: Image.Image,
one_prompt: SAMPrompt,
) -> None:
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
adapter.hq_mask_only = hq_mask_only
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
# See in test_sam.py test_predictor_resized_single_output
# to do torch.equal we need to resize the image before
# and to use image_embedding as input
size = (1024, 1024)
resized_tennis = tennis.resize(size)
# Reference
reference_sam_h_predictor.set_image(np.array(resized_tennis))
predictor_prompt = one_prompt.__dict__["box_points"]
masks_np, _, low_res_masks_np = reference_sam_h_predictor.predict(
box=np.array(predictor_prompt).flatten(),
multimask_output=False,
hq_token_only=hq_mask_only,
)
reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
# Refiners
# We bypass the refiners ViT by using directly the image features and interm_features
# from the reference implementation: this gives the ability to do torch.equal
reference_image_embedding = ImageEmbedding(features=reference_sam_h_predictor.features, original_image_size=size)
adapter.set_context("hq_sam", {"early_vit_embedding": reference_sam_h_predictor.interm_features[0]})
high_res_masks, _, low_res_masks = sam_h.predict(reference_image_embedding, **one_prompt.__dict__)
refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
assert torch.equal(
reference_low_res_mask_hq,
refiners_low_res_mask_hq,
)
assert torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() == 0
@no_grad() @no_grad()
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None: def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject() HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()

View file

@ -57,6 +57,8 @@ class FacebookSAMPredictor:
class FacebookSAMPredictorHQ: class FacebookSAMPredictorHQ:
model: FacebookSAM model: FacebookSAM
features: Tensor
interm_features: Tensor
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ... def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...