mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
HQ-SAM logit equal test, following #331
This commit is contained in:
parent
2763db960e
commit
5c937b184a
|
@ -144,7 +144,7 @@ exclude_also = [
|
||||||
|
|
||||||
[tool.typos.default]
|
[tool.typos.default]
|
||||||
extend-words = { adaptee = "adaptee" }
|
extend-words = { adaptee = "adaptee" }
|
||||||
extend-ignore-identifiers-re = ["NDArray*"]
|
extend-ignore-identifiers-re = ["NDArray*", "interm"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
filterwarnings = [
|
filterwarnings = [
|
||||||
|
|
|
@ -23,7 +23,7 @@ from refiners.foundationals.segment_anything.hq_sam import (
|
||||||
MaskDecoderTokensExtender,
|
MaskDecoderTokensExtender,
|
||||||
PredictionsPostProc,
|
PredictionsPostProc,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
from refiners.foundationals.segment_anything.model import ImageEmbedding, SegmentAnythingH
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -248,8 +248,8 @@ def test_predictor(
|
||||||
reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
|
reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
|
||||||
iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore
|
iou_predictions_np = torch.from_numpy(iou_predictions_np).to(dtype=torch.float32) # type: ignore
|
||||||
|
|
||||||
# NOTE: Diff on logits is relatively high, but on the same scale / even lower than base SAM logits diff (6e-3)
|
# NOTE: Diff on logits is relatively high,
|
||||||
# See https://github.com/finegrain-ai/refiners/blob/c6b5eb24a179d48e4542d94684a70c5ef3142ab1/tests/foundationals/segment_anything/test_sam.py#L426
|
# see test_predictor_equal for a stricter version
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
reference_low_res_mask_hq,
|
reference_low_res_mask_hq,
|
||||||
refiners_low_res_mask_hq,
|
refiners_low_res_mask_hq,
|
||||||
|
@ -265,6 +265,58 @@ def test_predictor(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("hq_mask_only", [True, False])
|
||||||
|
def test_predictor_equal(
|
||||||
|
sam_h: SegmentAnythingH,
|
||||||
|
hq_adapter_weights: Path,
|
||||||
|
hq_mask_only: bool,
|
||||||
|
reference_sam_h_predictor: FacebookSAMPredictorHQ,
|
||||||
|
tennis: Image.Image,
|
||||||
|
one_prompt: SAMPrompt,
|
||||||
|
) -> None:
|
||||||
|
adapter = HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||||
|
|
||||||
|
adapter.hq_mask_only = hq_mask_only
|
||||||
|
assert sam_h.ensure_find(PredictionsPostProc).hq_mask_only == hq_mask_only
|
||||||
|
|
||||||
|
# See in test_sam.py test_predictor_resized_single_output
|
||||||
|
# to do torch.equal we need to resize the image before
|
||||||
|
# and to use image_embedding as input
|
||||||
|
|
||||||
|
size = (1024, 1024)
|
||||||
|
resized_tennis = tennis.resize(size)
|
||||||
|
|
||||||
|
# Reference
|
||||||
|
reference_sam_h_predictor.set_image(np.array(resized_tennis))
|
||||||
|
|
||||||
|
predictor_prompt = one_prompt.__dict__["box_points"]
|
||||||
|
masks_np, _, low_res_masks_np = reference_sam_h_predictor.predict(
|
||||||
|
box=np.array(predictor_prompt).flatten(),
|
||||||
|
multimask_output=False,
|
||||||
|
hq_token_only=hq_mask_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
reference_high_res_mask_hq = torch.from_numpy(masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
|
||||||
|
reference_low_res_mask_hq = torch.from_numpy(low_res_masks_np[0, ...]).to(dtype=torch.float32) # type: ignore
|
||||||
|
|
||||||
|
# Refiners
|
||||||
|
|
||||||
|
# We bypass the refiners ViT by using directly the image features and interm_features
|
||||||
|
# from the reference implementation: this gives the ability to do torch.equal
|
||||||
|
reference_image_embedding = ImageEmbedding(features=reference_sam_h_predictor.features, original_image_size=size)
|
||||||
|
adapter.set_context("hq_sam", {"early_vit_embedding": reference_sam_h_predictor.interm_features[0]})
|
||||||
|
|
||||||
|
high_res_masks, _, low_res_masks = sam_h.predict(reference_image_embedding, **one_prompt.__dict__)
|
||||||
|
refiners_high_res_mask_hq = high_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
|
||||||
|
refiners_low_res_mask_hq = low_res_masks[0, 0, ...].to(dtype=torch.float32).detach().cpu()
|
||||||
|
|
||||||
|
assert torch.equal(
|
||||||
|
reference_low_res_mask_hq,
|
||||||
|
refiners_low_res_mask_hq,
|
||||||
|
)
|
||||||
|
assert torch.abs(reference_high_res_mask_hq - refiners_high_res_mask_hq).flatten().sum() == 0
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
|
def test_batch_mask_decoder(sam_h: SegmentAnythingH, hq_adapter_weights: Path) -> None:
|
||||||
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
HQSAMAdapter(sam_h, weights=load_from_safetensors(hq_adapter_weights)).inject()
|
||||||
|
|
|
@ -57,6 +57,8 @@ class FacebookSAMPredictor:
|
||||||
|
|
||||||
class FacebookSAMPredictorHQ:
|
class FacebookSAMPredictorHQ:
|
||||||
model: FacebookSAM
|
model: FacebookSAM
|
||||||
|
features: Tensor
|
||||||
|
interm_features: Tensor
|
||||||
|
|
||||||
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
|
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue