mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
sam: wrap high-level methods with no_grad
This commit is contained in:
parent
e7892254eb
commit
22ce3fd033
|
@ -39,6 +39,7 @@ class SegmentAnything(fl.Module):
|
|||
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
||||
original_size = (image.height, image.width)
|
||||
target_size = self.compute_target_size(original_size)
|
||||
|
@ -47,6 +48,7 @@ class SegmentAnything(fl.Module):
|
|||
original_image_size=original_size,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict(
|
||||
self,
|
||||
input: Image.Image | ImageEmbedding,
|
||||
|
|
|
@ -289,7 +289,6 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
|
|||
assert torch.equal(input=iou_prediction, other=facebook_prediction)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_predictor(
|
||||
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
|
||||
) -> None:
|
||||
|
@ -312,7 +311,6 @@ def test_predictor(
|
|||
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
|
||||
masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)
|
||||
|
||||
|
|
Loading…
Reference in a new issue