From e643006887bef0b4df79c794ff22155647eeabf0 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 27 Aug 2024 18:02:47 +0200 Subject: [PATCH] fix .to for MVANet --- .../foundationals/swin/mvanet/mclm.py | 10 ++++----- tests/e2e/test_mvanet.py | 22 +++++++++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/refiners/foundationals/swin/mvanet/mclm.py b/src/refiners/foundationals/swin/mvanet/mclm.py index b0a21f7..84de4f3 100644 --- a/src/refiners/foundationals/swin/mvanet/mclm.py +++ b/src/refiners/foundationals/swin/mvanet/mclm.py @@ -26,15 +26,14 @@ class PositionEmbeddingSine(fl.Module): Non-trainable position embedding, originally from https://github.com/facebookresearch/detr """ - def __init__(self, num_pos_feats: int, device: Device | None = None) -> None: + def __init__(self, num_pos_feats: int) -> None: super().__init__() - self.device = device temperature = 10000 - self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32, device=self.device) + self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32) self.dim_t = temperature ** (2 * (self.dim_t // 2) / num_pos_feats) def __call__(self, h: int, w: int) -> Tensor: - mask = torch.ones([1, h, w, 1], dtype=torch.bool, device=self.device) + mask = torch.ones([1, h, w, 1], dtype=torch.bool) y_embed = mask.cumsum(dim=1, dtype=torch.float32) x_embed = mask.cumsum(dim=2, dtype=torch.float32) @@ -129,7 +128,7 @@ class MCLM(fl.Chain): if pool_ratios is None: pool_ratios = [2, 8, 16] - positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device) + positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2) # LayerNorms in MCLM share their weights. # We use the `proxy` trick below so they can be present only @@ -174,6 +173,7 @@ class MCLM(fl.Chain): ), ), fl.Lambda(lambda t1, t2: (*t1, *t2)), # type: ignore + fl.Converter(set_dtype=False), GlobalAttention(emb_dim, num_heads, device=device), ln1, FeedForward(emb_dim, device=device), diff --git a/tests/e2e/test_mvanet.py b/tests/e2e/test_mvanet.py index 08a47ad..71833ba 100644 --- a/tests/e2e/test_mvanet.py +++ b/tests/e2e/test_mvanet.py @@ -57,3 +57,25 @@ def test_mvanet( prediction: torch.Tensor = mvanet_model(in_t.to(test_device)).sigmoid() cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR) ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB")) + + +@no_grad() +def test_mvanet_to( + mvanet_weights: Path, + ref_cactus: Image.Image, + expected_cactus_mask: Image.Image, + test_device: torch.device, +): + if test_device.type == "cpu": + warn("not running on CPU, skipping") + pytest.skip() + + model = MVANet(device=torch.device("cpu")).eval() + model.load_from_safetensors(mvanet_weights) + model.to(test_device) + + in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze() + in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0) + prediction: torch.Tensor = model(in_t.to(test_device)).sigmoid() + cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR) + ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))