mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-25 15:48:46 +00:00
sam: wrap high-level methods with no_grad
This commit is contained in:
parent
e7892254eb
commit
22ce3fd033
|
@ -39,6 +39,7 @@ class SegmentAnything(fl.Module):
|
||||||
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
|
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
|
||||||
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
|
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
||||||
original_size = (image.height, image.width)
|
original_size = (image.height, image.width)
|
||||||
target_size = self.compute_target_size(original_size)
|
target_size = self.compute_target_size(original_size)
|
||||||
|
@ -47,6 +48,7 @@ class SegmentAnything(fl.Module):
|
||||||
original_image_size=original_size,
|
original_image_size=original_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
input: Image.Image | ImageEmbedding,
|
input: Image.Image | ImageEmbedding,
|
||||||
|
|
|
@ -289,7 +289,6 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
|
||||||
assert torch.equal(input=iou_prediction, other=facebook_prediction)
|
assert torch.equal(input=iou_prediction, other=facebook_prediction)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def test_predictor(
|
def test_predictor(
|
||||||
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
|
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -312,7 +311,6 @@ def test_predictor(
|
||||||
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
|
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
|
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
|
||||||
masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)
|
masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue