From 22ce3fd033b1ce2cb117f4c90ca83a0914905a2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Deltheil?= Date: Tue, 19 Dec 2023 19:32:31 +0100 Subject: [PATCH] sam: wrap high-level methods with no_grad --- src/refiners/foundationals/segment_anything/model.py | 2 ++ tests/foundationals/segment_anything/test_sam.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index 6f83f72..f8abfb7 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -39,6 +39,7 @@ class SegmentAnything(fl.Module): self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype) self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype) + @torch.no_grad() def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding: original_size = (image.height, image.width) target_size = self.compute_target_size(original_size) @@ -47,6 +48,7 @@ class SegmentAnything(fl.Module): original_image_size=original_size, ) + @torch.no_grad() def predict( self, input: Image.Image | ImageEmbedding, diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index 1e00e64..0c5fbf9 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -289,7 +289,6 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N assert torch.equal(input=iou_prediction, other=facebook_prediction) -@torch.no_grad() def test_predictor( facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt ) -> None: @@ -312,7 +311,6 @@ def test_predictor( assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05) -@torch.no_grad() def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None: masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)