mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
fix .to for MVANet
This commit is contained in:
parent
5a1e1e70f8
commit
151b491831
|
@ -26,15 +26,14 @@ class PositionEmbeddingSine(fl.Module):
|
|||
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats: int, device: Device | None = None) -> None:
|
||||
def __init__(self, num_pos_feats: int) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
temperature = 10000
|
||||
self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32, device=self.device)
|
||||
self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32)
|
||||
self.dim_t = temperature ** (2 * (self.dim_t // 2) / num_pos_feats)
|
||||
|
||||
def __call__(self, h: int, w: int) -> Tensor:
|
||||
mask = torch.ones([1, h, w, 1], dtype=torch.bool, device=self.device)
|
||||
mask = torch.ones([1, h, w, 1], dtype=torch.bool)
|
||||
y_embed = mask.cumsum(dim=1, dtype=torch.float32)
|
||||
x_embed = mask.cumsum(dim=2, dtype=torch.float32)
|
||||
|
||||
|
@ -129,7 +128,7 @@ class MCLM(fl.Chain):
|
|||
if pool_ratios is None:
|
||||
pool_ratios = [2, 8, 16]
|
||||
|
||||
positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device)
|
||||
positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2)
|
||||
|
||||
# LayerNorms in MCLM share their weights.
|
||||
# We use the `proxy` trick below so they can be present only
|
||||
|
@ -174,6 +173,7 @@ class MCLM(fl.Chain):
|
|||
),
|
||||
),
|
||||
fl.Lambda(lambda t1, t2: (*t1, *t2)), # type: ignore
|
||||
fl.Converter(set_dtype=False),
|
||||
GlobalAttention(emb_dim, num_heads, device=device),
|
||||
ln1,
|
||||
FeedForward(emb_dim, device=device),
|
||||
|
|
|
@ -57,3 +57,25 @@ def test_mvanet(
|
|||
prediction: torch.Tensor = mvanet_model(in_t.to(test_device)).sigmoid()
|
||||
cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
|
||||
ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_mvanet_to(
|
||||
mvanet_weights: Path,
|
||||
ref_cactus: Image.Image,
|
||||
expected_cactus_mask: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
if test_device.type == "cpu":
|
||||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
model = MVANet(device=torch.device("cpu")).eval()
|
||||
model.load_from_safetensors(mvanet_weights)
|
||||
model.to(test_device)
|
||||
|
||||
in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
|
||||
in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
|
||||
prediction: torch.Tensor = model(in_t.to(test_device)).sigmoid()
|
||||
cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
|
||||
ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))
|
||||
|
|
Loading…
Reference in a new issue