mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
fix .to for MVANet
This commit is contained in:
parent
5a1e1e70f8
commit
151b491831
|
@ -26,15 +26,14 @@ class PositionEmbeddingSine(fl.Module):
|
||||||
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
|
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_pos_feats: int, device: Device | None = None) -> None:
|
def __init__(self, num_pos_feats: int) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
|
||||||
temperature = 10000
|
temperature = 10000
|
||||||
self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32, device=self.device)
|
self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32)
|
||||||
self.dim_t = temperature ** (2 * (self.dim_t // 2) / num_pos_feats)
|
self.dim_t = temperature ** (2 * (self.dim_t // 2) / num_pos_feats)
|
||||||
|
|
||||||
def __call__(self, h: int, w: int) -> Tensor:
|
def __call__(self, h: int, w: int) -> Tensor:
|
||||||
mask = torch.ones([1, h, w, 1], dtype=torch.bool, device=self.device)
|
mask = torch.ones([1, h, w, 1], dtype=torch.bool)
|
||||||
y_embed = mask.cumsum(dim=1, dtype=torch.float32)
|
y_embed = mask.cumsum(dim=1, dtype=torch.float32)
|
||||||
x_embed = mask.cumsum(dim=2, dtype=torch.float32)
|
x_embed = mask.cumsum(dim=2, dtype=torch.float32)
|
||||||
|
|
||||||
|
@ -129,7 +128,7 @@ class MCLM(fl.Chain):
|
||||||
if pool_ratios is None:
|
if pool_ratios is None:
|
||||||
pool_ratios = [2, 8, 16]
|
pool_ratios = [2, 8, 16]
|
||||||
|
|
||||||
positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device)
|
positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2)
|
||||||
|
|
||||||
# LayerNorms in MCLM share their weights.
|
# LayerNorms in MCLM share their weights.
|
||||||
# We use the `proxy` trick below so they can be present only
|
# We use the `proxy` trick below so they can be present only
|
||||||
|
@ -174,6 +173,7 @@ class MCLM(fl.Chain):
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
fl.Lambda(lambda t1, t2: (*t1, *t2)), # type: ignore
|
fl.Lambda(lambda t1, t2: (*t1, *t2)), # type: ignore
|
||||||
|
fl.Converter(set_dtype=False),
|
||||||
GlobalAttention(emb_dim, num_heads, device=device),
|
GlobalAttention(emb_dim, num_heads, device=device),
|
||||||
ln1,
|
ln1,
|
||||||
FeedForward(emb_dim, device=device),
|
FeedForward(emb_dim, device=device),
|
||||||
|
|
|
@ -57,3 +57,25 @@ def test_mvanet(
|
||||||
prediction: torch.Tensor = mvanet_model(in_t.to(test_device)).sigmoid()
|
prediction: torch.Tensor = mvanet_model(in_t.to(test_device)).sigmoid()
|
||||||
cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
|
cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
|
||||||
ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))
|
ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_mvanet_to(
|
||||||
|
mvanet_weights: Path,
|
||||||
|
ref_cactus: Image.Image,
|
||||||
|
expected_cactus_mask: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
if test_device.type == "cpu":
|
||||||
|
warn("not running on CPU, skipping")
|
||||||
|
pytest.skip()
|
||||||
|
|
||||||
|
model = MVANet(device=torch.device("cpu")).eval()
|
||||||
|
model.load_from_safetensors(mvanet_weights)
|
||||||
|
model.to(test_device)
|
||||||
|
|
||||||
|
in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
|
||||||
|
in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
|
||||||
|
prediction: torch.Tensor = model(in_t.to(test_device)).sigmoid()
|
||||||
|
cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
|
||||||
|
ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))
|
||||||
|
|
Loading…
Reference in a new issue