sam: wrap high-level methods with no_grad

This commit is contained in:
Cédric Deltheil 2023-12-19 19:32:31 +01:00 committed by Cédric Deltheil
parent e7892254eb
commit 22ce3fd033
2 changed files with 2 additions and 2 deletions

View file

@ -39,6 +39,7 @@ class SegmentAnything(fl.Module):
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype) self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype) self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
@torch.no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding: def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
original_size = (image.height, image.width) original_size = (image.height, image.width)
target_size = self.compute_target_size(original_size) target_size = self.compute_target_size(original_size)
@ -47,6 +48,7 @@ class SegmentAnything(fl.Module):
original_image_size=original_size, original_image_size=original_size,
) )
@torch.no_grad()
def predict( def predict(
self, self,
input: Image.Image | ImageEmbedding, input: Image.Image | ImageEmbedding,

View file

@ -289,7 +289,6 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
assert torch.equal(input=iou_prediction, other=facebook_prediction) assert torch.equal(input=iou_prediction, other=facebook_prediction)
@torch.no_grad()
def test_predictor( def test_predictor(
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
) -> None: ) -> None:
@ -312,7 +311,6 @@ def test_predictor(
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05) assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
@torch.no_grad()
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None: def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__) masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)