fix .to for MVANet
Some checks are pending
CI / lint_and_typecheck (push) Waiting to run
Deploy docs to GitHub Pages / Deploy docs (push) Waiting to run
Spell checker / Spell check (push) Waiting to run

This commit is contained in:
Pierre Chapuis 2024-08-27 18:02:47 +02:00
parent 0046d2288f
commit e643006887
2 changed files with 27 additions and 5 deletions

View file

@ -26,15 +26,14 @@ class PositionEmbeddingSine(fl.Module):
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
"""
def __init__(self, num_pos_feats: int, device: Device | None = None) -> None:
def __init__(self, num_pos_feats: int) -> None:
super().__init__()
self.device = device
temperature = 10000
self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32, device=self.device)
self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32)
self.dim_t = temperature ** (2 * (self.dim_t // 2) / num_pos_feats)
def __call__(self, h: int, w: int) -> Tensor:
mask = torch.ones([1, h, w, 1], dtype=torch.bool, device=self.device)
mask = torch.ones([1, h, w, 1], dtype=torch.bool)
y_embed = mask.cumsum(dim=1, dtype=torch.float32)
x_embed = mask.cumsum(dim=2, dtype=torch.float32)
@ -129,7 +128,7 @@ class MCLM(fl.Chain):
if pool_ratios is None:
pool_ratios = [2, 8, 16]
positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device)
positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2)
# LayerNorms in MCLM share their weights.
# We use the `proxy` trick below so they can be present only
@ -174,6 +173,7 @@ class MCLM(fl.Chain):
),
),
fl.Lambda(lambda t1, t2: (*t1, *t2)), # type: ignore
fl.Converter(set_dtype=False),
GlobalAttention(emb_dim, num_heads, device=device),
ln1,
FeedForward(emb_dim, device=device),

View file

@ -57,3 +57,25 @@ def test_mvanet(
prediction: torch.Tensor = mvanet_model(in_t.to(test_device)).sigmoid()
cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))
@no_grad()
def test_mvanet_to(
mvanet_weights: Path,
ref_cactus: Image.Image,
expected_cactus_mask: Image.Image,
test_device: torch.device,
):
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
model = MVANet(device=torch.device("cpu")).eval()
model.load_from_safetensors(mvanet_weights)
model.to(test_device)
in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
prediction: torch.Tensor = model(in_t.to(test_device)).sigmoid()
cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))