diff --git a/docs/reference/SUMMARY.md b/docs/reference/SUMMARY.md
index 8338a24..76475d6 100644
--- a/docs/reference/SUMMARY.md
+++ b/docs/reference/SUMMARY.md
@@ -9,4 +9,4 @@
 * [DINOv2](foundationals/dinov2.md)
 * [Latent Diffusion](foundationals/latent_diffusion.md)
 * [Segment Anything](foundationals/segment_anything.md)
-
+ * [Swin Transformers](foundationals/swin.md)
diff --git a/docs/reference/foundationals/swin.md b/docs/reference/foundationals/swin.md
new file mode 100644
index 0000000..02c39b3
--- /dev/null
+++ b/docs/reference/foundationals/swin.md
@@ -0,0 +1,2 @@
+::: refiners.foundationals.swin.swin_transformer
+::: refiners.foundationals.swin.mvanet
diff --git a/pyproject.toml b/pyproject.toml
index c3962e8..ee6909f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,6 +58,7 @@ conversion = [
"segment-anything-py>=1.0",
"requests>=2.26.0",
"tqdm>=4.62.3",
+ "gdown>=5.2.0",
]
doc = [
# required by mkdocs to format the signatures
diff --git a/requirements.lock b/requirements.lock
index a56a5a5..f9ec260 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -31,6 +31,8 @@ babel==2.15.0
# via mkdocs-material
backports-strenum==1.3.1
# via griffe
+beautifulsoup4==4.12.3
+ # via gdown
bitsandbytes==0.43.3
# via refiners
black==24.4.2
@@ -70,6 +72,7 @@ docker-pycreds==0.4.0
filelock==3.15.4
# via datasets
# via diffusers
+ # via gdown
# via huggingface-hub
# via torch
# via transformers
@@ -85,6 +88,8 @@ fsspec==2024.5.0
# via torch
future==1.0.0
# via neptune
+gdown==5.2.0
+ # via refiners
ghp-import==2.1.0
# via mkdocs
gitdb==4.0.11
@@ -274,6 +279,8 @@ pyjwt==2.9.0
pymdown-extensions==10.9
# via mkdocs-material
# via mkdocstrings
+pysocks==1.7.1
+ # via requests
python-dateutil==2.9.0.post0
# via arrow
# via botocore
@@ -311,6 +318,7 @@ requests==2.32.3
# via bravado-core
# via datasets
# via diffusers
+ # via gdown
# via huggingface-hub
# via mkdocs-material
# via neptune
@@ -356,6 +364,8 @@ six==1.16.0
# via rfc3339-validator
smmap==5.0.1
# via gitdb
+soupsieve==2.6
+ # via beautifulsoup4
swagger-spec-validator==3.0.4
# via bravado-core
# via neptune
@@ -383,6 +393,7 @@ torchvision==0.19.0
# via timm
tqdm==4.66.4
# via datasets
+ # via gdown
# via huggingface-hub
# via refiners
# via transformers
diff --git a/scripts/conversion/convert_mvanet.py b/scripts/conversion/convert_mvanet.py
new file mode 100644
index 0000000..e5b30e7
--- /dev/null
+++ b/scripts/conversion/convert_mvanet.py
@@ -0,0 +1,40 @@
+import argparse
+from pathlib import Path
+
+from refiners.fluxion.utils import load_tensors, save_to_safetensors
+from refiners.foundationals.swin.mvanet.converter import convert_weights
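+
+# Example invocation (paths are illustrative):
+#   python scripts/conversion/convert_mvanet.py --from Model_80.pth --to mvanet.safetensors --half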
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--from",
+ type=str,
+ required=True,
+ dest="source_path",
+ help="A MVANet checkpoint. One can be found at https://github.com/qianyu-dlut/MVANet",
+ )
+ parser.add_argument(
+ "--to",
+ type=str,
+ dest="output_path",
+ default=None,
+ help=(
+ "Path to save the converted model. If not specified, the output path will be the source path with the"
+ " extension changed to .safetensors."
+ ),
+ )
+ parser.add_argument("--half", action="store_true", dest="half")
+ args = parser.parse_args()
+
+ src_weights = load_tensors(args.source_path)
+ weights = convert_weights(src_weights)
+ if args.half:
+ weights = {key: value.half() for key, value in weights.items()}
+ if args.output_path is None:
+ args.output_path = f"{Path(args.source_path).stem}.safetensors"
+ save_to_safetensors(path=args.output_path, tensors=weights)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py
index c20447c..6f27a67 100644
--- a/scripts/prepare_test_weights.py
+++ b/scripts/prepare_test_weights.py
@@ -11,6 +11,7 @@ import subprocess
import sys
from urllib.parse import urlparse
+import gdown
import requests
from tqdm import tqdm
@@ -446,6 +447,25 @@ def download_ic_light():
)
+def download_mvanet():
+ fn = "Model_80.pth"
+ dest_folder = os.path.join(test_weights_dir, "mvanet")
+ dest_filename = os.path.join(dest_folder, fn)
+
+ if os.environ.get("DRY_RUN") == "1":
+ return
+
+ if os.path.exists(dest_filename):
+ print(f"✖️ ️ Skipping previously downloaded mvanet/{fn}")
+ else:
+ os.makedirs(dest_folder, exist_ok=True)
+ print(f"🔽 Downloading mvanet/{fn} => '{rel(dest_filename)}'", end="\n")
+ gdown.download(id="1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv", output=dest_filename, quiet=True)
+ print(f"{previous_line}✅ Downloaded mvanet/{fn} => '{rel(dest_filename)}' ")
+
+ check_hash(dest_filename, "b915d492")
+
+
def printg(msg: str):
"""print in green color"""
print("\033[92m" + msg + "\033[0m")
@@ -808,6 +828,16 @@ def convert_ic_light():
)
+def convert_mvanet():
+ run_conversion_script(
+ "convert_mvanet.py",
+ "tests/weights/mvanet/Model_80.pth",
+ "tests/weights/mvanet/mvanet.safetensors",
+ half=True,
+ expected_hash="bf9ae4cb",
+ )
+
+
def download_all():
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
download_sd15("runwayml/stable-diffusion-v1-5")
@@ -830,6 +860,7 @@ def download_all():
download_sdxl_lightning_base()
download_sdxl_lightning_lora()
download_ic_light()
+ download_mvanet()
def convert_all():
@@ -850,6 +881,7 @@ def convert_all():
convert_lcm_base()
convert_sdxl_lightning_base()
convert_ic_light()
+ convert_mvanet()
def main():
diff --git a/src/refiners/foundationals/swin/__init__.py b/src/refiners/foundationals/swin/__init__.py
new file mode 100644
index 0000000..51cfb22
--- /dev/null
+++ b/src/refiners/foundationals/swin/__init__.py
@@ -0,0 +1,3 @@
+from .swin_transformer import SwinTransformer
+
+__all__ = ["SwinTransformer"]
diff --git a/src/refiners/foundationals/swin/mvanet/__init__.py b/src/refiners/foundationals/swin/mvanet/__init__.py
new file mode 100644
index 0000000..5c54590
--- /dev/null
+++ b/src/refiners/foundationals/swin/mvanet/__init__.py
@@ -0,0 +1,3 @@
+from .mvanet import MVANet
+
+__all__ = ["MVANet"]
diff --git a/src/refiners/foundationals/swin/mvanet/converter.py b/src/refiners/foundationals/swin/mvanet/converter.py
new file mode 100644
index 0000000..daacd69
--- /dev/null
+++ b/src/refiners/foundationals/swin/mvanet/converter.py
@@ -0,0 +1,138 @@
+import re
+
+from torch import Tensor
+
+
+def convert_weights(official_state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
+ rm_list = [
+ # Official weights contains useless keys
+ # See https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425
+ r"multifieldcrossatt.linear[56]",
+ r"multifieldcrossatt.attention.5",
+ r"dec_blk\d+\.linear[12]",
+ r"dec_blk[1234]\.attention\.[4567]",
+ # We don't need the sideout weights
+ r"sideout\d+",
+ ]
+ state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)}
+
+ keys_map: dict[str, str] = {}
+ for k in state_dict.keys():
+ v: str = k
+
+ def rpfx(s: str, src: str, dst: str) -> str:
+ if not s.startswith(src):
+ return s
+ return s.replace(src, dst, 1)
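+        # e.g. rpfx("backbone.patch_embed.norm.weight", "backbone.patch_embed.norm.",
+        #      "SwinTransformer.PatchEmbedding.LayerNorm.") -> "SwinTransformer.PatchEmbedding.LayerNorm.weight"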
+
+ # Swin Transformer backbone
+
+ v = rpfx(v, "backbone.patch_embed.proj.", "SwinTransformer.PatchEmbedding.Conv2d.")
+ v = rpfx(v, "backbone.patch_embed.norm.", "SwinTransformer.PatchEmbedding.LayerNorm.")
+
+ if m := re.match(r"backbone\.layers\.(\d+)\.downsample\.(.*)", v):
+ s = m.group(2).replace("reduction.", "Linear.").replace("norm.", "LayerNorm.")
+ v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.PatchMerging.{s}"
+
+ if m := re.match(r"backbone\.layers\.(\d+)\.blocks\.(\d+)\.(.*)", v):
+ s = m.group(3)
+ s = s.replace("norm1.", "Residual_1.LayerNorm.")
+ s = s.replace("norm2.", "Residual_2.LayerNorm.")
+
+ s = s.replace("attn.qkv.", "Residual_1.WindowAttention.Linear_1.")
+ s = s.replace("attn.proj.", "Residual_1.WindowAttention.Linear_2.")
+ s = s.replace("attn.relative_position", "Residual_1.WindowAttention.WindowSDPA.rpb.relative_position")
+
+ s = s.replace("mlp.fc", "Residual_2.Linear_")
+ v = ".".join(
+ [
+ f"SwinTransformer.Chain_{int(m.group(1)) + 1}",
+ f"BasicLayer.SwinTransformerBlock_{int(m.group(2)) + 1}",
+ s,
+ ]
+ )
+
+ if m := re.match(r"backbone\.norm(\d+)\.(.*)", v):
+ v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.Passthrough.LayerNorm.{m.group(2)}"
+
+ # MVANet
+
+ def mclm(s: str, pfx_src: str, pfx_dst: str) -> str:
+ pca = f"{pfx_dst}Residual.PatchwiseCrossAttention"
+ s = rpfx(s, f"{pfx_src}linear1.", f"{pfx_dst}FeedForward_1.Linear_1.")
+ s = rpfx(s, f"{pfx_src}linear2.", f"{pfx_dst}FeedForward_1.Linear_2.")
+ s = rpfx(s, f"{pfx_src}linear3.", f"{pfx_dst}FeedForward_2.Linear_1.")
+ s = rpfx(s, f"{pfx_src}linear4.", f"{pfx_dst}FeedForward_2.Linear_2.")
+ s = rpfx(s, f"{pfx_src}norm1.", f"{pfx_dst}LayerNorm_1.")
+ s = rpfx(s, f"{pfx_src}norm2.", f"{pfx_dst}LayerNorm_2.")
+ s = rpfx(s, f"{pfx_src}attention.0.", f"{pfx_dst}GlobalAttention.Sum.Chain.MultiheadAttention.")
+ s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.")
+ s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.")
+ s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.")
+ s = rpfx(s, f"{pfx_src}attention.4.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.")
+ return s
+
+ def mcrm(s: str, pfx_src: str, pfx_dst: str) -> str:
+ # Note: there are no linear{1,2}, see https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425
+ tca = f"{pfx_dst}Parallel_3.TiledCrossAttention"
+ pca = f"{tca}.Sum.Chain_2.PatchwiseCrossAttention"
+ s = rpfx(s, f"{pfx_src}linear3.", f"{tca}.FeedForward.Linear_1.")
+ s = rpfx(s, f"{pfx_src}linear4.", f"{tca}.FeedForward.Linear_2.")
+ s = rpfx(s, f"{pfx_src}norm1.", f"{tca}.LayerNorm_1.")
+ s = rpfx(s, f"{pfx_src}norm2.", f"{tca}.LayerNorm_2.")
+ s = rpfx(s, f"{pfx_src}attention.0.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.")
+ s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.")
+ s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.")
+ s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.")
+ s = rpfx(s, f"{pfx_src}sal_conv.", f"{pfx_dst}Parallel_2.Multiply.Chain.Conv2d.")
+ return s
+
+ def cbr(s: str, pfx_src: str, pfx_dst: str, shift: int = 0) -> str:
+ s = rpfx(s, f"{pfx_src}{shift}.", f"{pfx_dst}Conv2d.")
+ s = rpfx(s, f"{pfx_src}{shift + 1}.", f"{pfx_dst}BatchNorm2d.")
+ s = rpfx(s, f"{pfx_src}{shift + 2}.", f"{pfx_dst}PReLU.")
+ return s
+
+ def cbg(s: str, pfx_src: str, pfx_dst: str) -> str:
+ s = rpfx(s, f"{pfx_src}0.", f"{pfx_dst}Conv2d.")
+ s = rpfx(s, f"{pfx_src}1.", f"{pfx_dst}BatchNorm2d.")
+ return s
+
+ v = rpfx(v, "shallow.0.", "ComputeShallow.Conv2d.")
+
+ v = cbr(v, "output1.", "Pyramid.Sum.Chain.CBR.")
+ v = cbr(v, "output2.", "Pyramid.Sum.PyramidL2.Sum.Chain.CBR.")
+ v = cbr(v, "output3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.Chain.CBR.")
+ v = cbr(v, "output4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.Chain.CBR.")
+ v = cbr(v, "output5.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.CBR.")
+
+ v = cbr(v, "conv1.", "Pyramid.CBR.")
+ v = cbr(v, "conv2.", "Pyramid.Sum.PyramidL2.CBR.")
+ v = cbr(v, "conv3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.CBR.")
+ v = cbr(v, "conv4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.CBR.")
+
+ v = mclm(v, "multifieldcrossatt.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.MCLM.")
+
+ v = mcrm(v, "dec_blk1.", "Pyramid.MCRM.")
+ v = mcrm(v, "dec_blk2.", "Pyramid.Sum.PyramidL2.MCRM.")
+ v = mcrm(v, "dec_blk3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.MCRM.")
+ v = mcrm(v, "dec_blk4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.MCRM.")
+
+ v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_1.")
+ v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_2.", shift=3)
+
+ v = rpfx(v, "insmask_head.6.", "RearrangeMultiView.Chain.Conv2d.")
+
+ v = cbg(v, "upsample1.", "ShallowUpscaler.Sum_2.Chain_1.CBG.")
+ v = cbg(v, "upsample2.", "ShallowUpscaler.CBG.")
+
+ v = rpfx(v, "output.0.", "Conv2d.")
+
+ if v != k:
+ keys_map[k] = v
+
+ for key, new_key in keys_map.items():
+ state_dict[new_key] = state_dict[key]
+ state_dict.pop(key)
+
+ return state_dict
diff --git a/src/refiners/foundationals/swin/mvanet/mclm.py b/src/refiners/foundationals/swin/mvanet/mclm.py
new file mode 100644
index 0000000..041308a
--- /dev/null
+++ b/src/refiners/foundationals/swin/mvanet/mclm.py
@@ -0,0 +1,211 @@
+# Multi-View Complementary Localization
+
+import math
+
+import torch
+from torch import Tensor, device as Device
+
+import refiners.fluxion.layers as fl
+from refiners.fluxion.context import Contexts
+
+from .utils import FeedForward, MultiheadAttention, MultiPool, PatchMerge, PatchwiseCrossAttention, Unflatten
+
+
+class PerPixel(fl.Chain):
+ """(B, C, H, W) -> H*W, B, C"""
+
+ def __init__(self):
+ super().__init__(
+ fl.Permute(2, 3, 0, 1),
+ fl.Flatten(0, 1),
+ )
+
+
+class PositionEmbeddingSine(fl.Module):
+ """
+ Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
+ """
+
+ def __init__(self, num_pos_feats: int, device: Device | None = None):
+ super().__init__()
+ self.device = device
+ temperature = 10000
+ self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32, device=self.device)
+ self.dim_t = temperature ** (2 * (self.dim_t // 2) / num_pos_feats)
+
+ def __call__(self, h: int, w: int) -> Tensor:
+ mask = torch.ones([1, h, w, 1], dtype=torch.bool, device=self.device)
+ y_embed = mask.cumsum(dim=1, dtype=torch.float32)
+ x_embed = mask.cumsum(dim=2, dtype=torch.float32)
+
+ eps, scale = 1e-6, 2 * math.pi
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * scale
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * scale
+
+ pos_x = x_embed / self.dim_t
+ pos_y = y_embed / self.dim_t
+
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ return torch.cat((pos_y, pos_x), dim=3).permute(1, 2, 0, 3).flatten(0, 1)
+
+
+class MultiPoolPos(fl.Module):
+ def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine):
+ super().__init__()
+ self.pool_ratios = pool_ratios
+ self.positional_embedding = positional_embedding
+
+ def forward(self, *args: int) -> Tensor:
+ h, w = args
+ return torch.cat([self.positional_embedding(h // ratio, w // ratio) for ratio in self.pool_ratios])
+
+
+class Repeat(fl.Module):
+ def __init__(self, dim: int = 0):
+ self.dim = dim
+ super().__init__()
+
+ def forward(self, x: Tensor, n: int) -> Tensor:
+ return torch.repeat_interleave(x, n, dim=self.dim)
+
+
+class _MHA_Arg(fl.Sum):
+ def __init__(self, offset: int):
+ self.offset = offset
+ super().__init__(
+ fl.GetArg(offset), # value
+ fl.Chain(
+ fl.Parallel(
+ fl.GetArg(self.offset + 1), # position embedding
+ fl.Lambda(self._batch_size),
+ ),
+ Repeat(1),
+ ),
+ )
+
+ def _batch_size(self, *args: Tensor) -> int:
+ return args[self.offset].size(1)
+
+
+class GlobalAttention(fl.Chain):
+ # Input must be a 4-tuple: (global, global pos. emb, pools, pools pos. emb.)
+ def __init__(
+ self,
+ emb_dim: int,
+ num_heads: int = 1,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Sum(
+ fl.GetArg(0), # global
+ fl.Chain(
+ fl.Parallel(
+ _MHA_Arg(0), # Q: global + pos. emb
+ _MHA_Arg(2), # K: pools + pos. emb
+ fl.GetArg(2), # V: pools
+ ),
+ MultiheadAttention(emb_dim, num_heads, device=device),
+ ),
+ ),
+ )
+
+
+class MCLM(fl.Chain):
+ """Multi-View Complementary Localization Module
+ Inputs:
+ tensor: (b, 5, e, h, h)
+ Outputs:
+ tensor: (b, 5, e, h, h)
+ """
+
+ def __init__(
+ self,
+ emb_dim: int,
+ num_heads: int = 1,
+ pool_ratios: list[int] | None = None,
+ device: Device | None = None,
+ ):
+ if pool_ratios is None:
+ pool_ratios = [2, 8, 16]
+
+ positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device)
+
+ # LayerNorms in MCLM share their weights.
+
+ ln1 = fl.LayerNorm(emb_dim, device=device)
+ ln2 = fl.LayerNorm(emb_dim, device=device)
+
+ def proxy(m: fl.Module) -> fl.Module:
+ def f(x: Tensor) -> Tensor:
+ return m(x)
+
+ return fl.Lambda(f)
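+        # Wrapping in a Lambda lets the same LayerNorm instance appear at two
+        # places in the Chain while only being registered once as a submodule.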
+
+ super().__init__(
+ fl.Parallel(
+ fl.Chain( # global
+ fl.Slicing(dim=1, start=4),
+ fl.Squeeze(1),
+ fl.Parallel(
+ PerPixel(), # glb
+ fl.Chain( # g_pos
+ fl.Lambda(lambda x: x.shape[-2:]), # type: ignore
+ positional_embedding,
+ ),
+ ),
+ ),
+ fl.Chain( # local
+ fl.Slicing(dim=1, end=4),
+ fl.SetContext("mclm", "local"),
+ PatchMerge(),
+ fl.Parallel(
+ fl.Chain( # pool
+ MultiPool(pool_ratios),
+ fl.Squeeze(0),
+ ),
+ fl.Chain( # pool_pos
+ fl.Lambda(lambda x: x.shape[-2:]), # type: ignore
+ MultiPoolPos(pool_ratios, positional_embedding),
+ ),
+ ),
+ ),
+ ),
+ fl.Lambda(lambda t1, t2: (*t1, *t2)), # type: ignore
+ GlobalAttention(emb_dim, num_heads, device=device),
+ ln1,
+ FeedForward(emb_dim, device=device),
+ ln2,
+ fl.SetContext("mclm", "global"),
+ fl.UseContext("mclm", "local"),
+ fl.Flatten(-2, -1),
+ fl.Permute(1, 3, 0, 2),
+ fl.Residual(
+ fl.Parallel(
+ fl.Identity(),
+ fl.Chain(
+ fl.UseContext("mclm", "global"),
+ Unflatten(0, (2, 8, 2, 8)), # 2, h/2, 2, h/2
+ fl.Permute(0, 2, 1, 3, 4, 5),
+ fl.Flatten(0, 1),
+ fl.Flatten(1, 2),
+ ),
+ ),
+ PatchwiseCrossAttention(emb_dim, num_heads, device=device),
+ ),
+ proxy(ln1),
+ FeedForward(emb_dim, device=device),
+ proxy(ln2),
+ fl.Concatenate(
+ fl.Identity(),
+ fl.Chain(
+ fl.UseContext("mclm", "global"),
+ fl.Unsqueeze(0),
+ ),
+ ),
+ Unflatten(1, (16, 16)), # h, h
+ fl.Permute(3, 0, 4, 1, 2),
+ )
+
+ def init_context(self) -> Contexts:
+ return {"mclm": {"global": None, "local": None}}
diff --git a/src/refiners/foundationals/swin/mvanet/mcrm.py b/src/refiners/foundationals/swin/mvanet/mcrm.py
new file mode 100644
index 0000000..01311da
--- /dev/null
+++ b/src/refiners/foundationals/swin/mvanet/mcrm.py
@@ -0,0 +1,119 @@
+# Multi-View Complementary Refinement
+
+import torch
+from torch import Tensor, device as Device
+
+import refiners.fluxion.layers as fl
+
+from .utils import FeedForward, Interpolate, MultiPool, PatchMerge, PatchSplit, PatchwiseCrossAttention, Unflatten
+
+
+class Multiply(fl.Chain):
+ def __init__(self, o1: fl.Module, o2: fl.Module) -> None:
+ super().__init__(o1, o2)
+
+ def forward(self, *args: Tensor) -> Tensor:
+ return torch.mul(self[0](*args), self[1](*args))
+
+
+class TiledCrossAttention(fl.Chain):
+ def __init__(
+ self,
+ emb_dim: int,
+ dim: int,
+ num_heads: int = 1,
+ pool_ratios: list[int] | None = None,
+ device: Device | None = None,
+ ):
+ # Input must be a 4-tuple: (local, global)
+
+ if pool_ratios is None:
+ pool_ratios = [1, 2, 4]
+
+ super().__init__(
+ fl.Distribute(
+ fl.Chain( # local
+ fl.Flatten(-2, -1),
+ fl.Permute(1, 3, 0, 2),
+ ),
+ fl.Chain( # global
+ PatchSplit(),
+ fl.Squeeze(0),
+ MultiPool(pool_ratios),
+ ),
+ ),
+ fl.Sum(
+ fl.Chain(
+ fl.GetArg(0),
+ fl.Permute(2, 1, 0, 3),
+ ),
+ fl.Chain(
+ PatchwiseCrossAttention(emb_dim, num_heads, device=device),
+ fl.Permute(2, 1, 0, 3),
+ ),
+ ),
+ fl.LayerNorm(emb_dim, device=device),
+ FeedForward(emb_dim, device=device),
+ fl.LayerNorm(emb_dim, device=device),
+ fl.Permute(0, 2, 3, 1),
+ Unflatten(-1, (dim, dim)),
+ )
+
+
+class MCRM(fl.Chain):
+ """Multi-View Complementary Refinement"""
+
+ def __init__(
+ self,
+ emb_dim: int,
+ size: int,
+ num_heads: int = 1,
+ pool_ratios: list[int] | None = None,
+ device: Device | None = None,
+ ):
+ if pool_ratios is None:
+ pool_ratios = [1, 2, 4]
+
+ super().__init__(
+ fl.Parallel(
+ fl.Chain( # local
+ fl.Slicing(dim=1, end=4),
+ ),
+ fl.Chain( # global
+ fl.Slicing(dim=1, start=4),
+ fl.Squeeze(1),
+ ),
+ ),
+ fl.Parallel(
+ Multiply(
+ fl.GetArg(0),
+ fl.Chain(
+ fl.GetArg(1),
+ fl.Conv2d(emb_dim, 1, 1, device=device),
+ fl.Sigmoid(),
+ Interpolate((size * 2, size * 2), "nearest"),
+ PatchSplit(),
+ ),
+ ),
+ fl.GetArg(1),
+ ),
+ fl.Parallel(
+ TiledCrossAttention(emb_dim, size, num_heads, pool_ratios, device=device),
+ fl.GetArg(1),
+ ),
+ fl.Concatenate(
+ fl.GetArg(0),
+ fl.Chain(
+ fl.Sum(
+ fl.GetArg(1),
+ fl.Chain(
+ fl.GetArg(0),
+ PatchMerge(),
+ Interpolate((size, size), "nearest"),
+ ),
+ ),
+ fl.Unsqueeze(1),
+ ),
+ dim=1,
+ ),
+ )
diff --git a/src/refiners/foundationals/swin/mvanet/mvanet.py b/src/refiners/foundationals/swin/mvanet/mvanet.py
new file mode 100644
index 0000000..4410eb6
--- /dev/null
+++ b/src/refiners/foundationals/swin/mvanet/mvanet.py
@@ -0,0 +1,337 @@
+# Multi-View Aggregation Network (arXiv:2404.07445)
+
+from torch import device as Device
+
+import refiners.fluxion.layers as fl
+from refiners.fluxion.context import Contexts
+from refiners.foundationals.swin.swin_transformer import SwinTransformer
+
+from .mclm import MCLM # Multi-View Complementary Localization
+from .mcrm import MCRM # Multi-View Complementary Refinement
+from .utils import BatchNorm2d, Interpolate, PatchMerge, PatchSplit, PReLU, Rescale, Unflatten
+
+
+class CBG(fl.Chain):
+ """(C)onvolution + (B)atchNorm + (G)eLU"""
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int | None = None,
+ device: Device | None = None,
+ ):
+ out_dim = out_dim or in_dim
+ super().__init__(
+ fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
+ BatchNorm2d(out_dim, device=device),
+ fl.GeLU(),
+ )
+
+
+class CBR(fl.Chain):
+ """(C)onvolution + (B)atchNorm + Parametric (R)eLU"""
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int | None = None,
+ device: Device | None = None,
+ ):
+ out_dim = out_dim or in_dim
+ super().__init__(
+ fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
+ BatchNorm2d(out_dim, device=device),
+ PReLU(device=device),
+ )
+
+
+class SplitMultiView(fl.Chain):
+ """
+    Split a high-resolution tensor into 5 lower-resolution views (1 global view + 4 tiles).
+ See also the reverse Module [`RearrangeMultiView`][refiners.foundationals.swin.mvanet.RearrangeMultiView]
+
+ Inputs:
+ single_view (b, c, H, W)
+
+ Outputs:
+ multi_view (b, 5, c, H/2, W/2)
+ """
+
+ def __init__(self):
+ super().__init__(
+ fl.Concatenate(
+ PatchSplit(), # global features
+ fl.Chain( # local features
+ Rescale(scale_factor=0.5, mode="bilinear"),
+ fl.Unsqueeze(1),
+ ),
+ dim=1,
+ )
+ )
+
+
+class ShallowUpscaler(fl.Chain):
+ """4x Upscaler reusing the image as input to upscale the feature
+ See [[arXiv:2108.10257] SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/abs/2108.10257)
+
+ Args:
+ embedding_dim (int): the embedding dimension
+
+ Inputs:
+ feature (b, E, image_size/4, image_size/4)
+
+ Output:
+ upscaled tensor (b, E, image_size, image_size)
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Sum(
+ fl.Identity(),
+ fl.Chain(
+ fl.UseContext("mvanet", "shallow"),
+ Interpolate((256, 256)),
+ ),
+ ),
+ fl.Sum(
+ fl.Chain(
+ Rescale(2),
+ CBG(embedding_dim, device=device),
+ ),
+ fl.Chain(
+ fl.UseContext("mvanet", "shallow"),
+ Interpolate((512, 512)),
+ ),
+ ),
+ Rescale(2),
+ CBG(embedding_dim, device=device),
+ )
+
+
+class PyramidL5(fl.Chain):
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.GetArg(0), # output5
+ fl.Flatten(0, 1),
+ CBR(1024, embedding_dim, device=device),
+ Unflatten(0, (-1, 5)),
+ MCLM(embedding_dim, device=device),
+ fl.Flatten(0, 1),
+ Interpolate((32, 32)),
+ )
+
+
+class PyramidL4(fl.Chain):
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Sum(
+ PyramidL5(embedding_dim=embedding_dim, device=device),
+ fl.Chain(
+ fl.GetArg(1),
+ fl.Flatten(0, 1),
+ CBR(512, embedding_dim, device=device), # output4
+ Unflatten(0, (-1, 5)),
+ ),
+ ),
+ MCRM(embedding_dim, 32, device=device), # dec_blk4
+ fl.Flatten(0, 1),
+ CBR(embedding_dim, device=device), # conv4
+ Interpolate((64, 64)),
+ )
+
+
+class PyramidL3(fl.Chain):
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Sum(
+ PyramidL4(embedding_dim=embedding_dim, device=device),
+ fl.Chain(
+ fl.GetArg(2),
+ fl.Flatten(0, 1),
+ CBR(256, embedding_dim, device=device), # output3
+ Unflatten(0, (-1, 5)),
+ ),
+ ),
+ MCRM(embedding_dim, 64, device=device), # dec_blk3
+ fl.Flatten(0, 1),
+ CBR(embedding_dim, device=device), # conv3
+ Interpolate((128, 128)),
+ )
+
+
+class PyramidL2(fl.Chain):
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Sum(
+ PyramidL3(embedding_dim=embedding_dim, device=device),
+ fl.Chain(
+ fl.GetArg(3),
+ fl.Flatten(0, 1),
+ CBR(128, embedding_dim, device=device), # output2
+ Unflatten(0, (-1, 5)),
+ ),
+ ),
+ MCRM(embedding_dim, 128, device=device), # dec_blk2
+ fl.Flatten(0, 1),
+ CBR(embedding_dim, device=device), # conv2
+ Interpolate((128, 128)),
+ )
+
+
+class Pyramid(fl.Chain):
+ """
+ Recursive Pyramidal Network calling MCLM and MCRM blocks
+
+    It acts as an FPN (Feature Pyramid Network) neck for MVANet,
+ see [[arXiv:1612.03144] Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
+
+ Inputs:
+ features: a pyramid of N = 5 tensors
+        shapes are (b, 5, E_{0}, S_{0}, S_{0}), ..., (b, 5, E_{i}, S_{i}, S_{i}), ..., (b, 5, E_{N-1}, S_{N-1}, S_{N-1})
+ with S_{i} = S_{i-1} or S_{i} = 2*S_{i-1} for 0 < i < N
+
+ Outputs:
+ output (b, 5, E, S_{N-1}, S_{N-1})
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Sum(
+ PyramidL2(embedding_dim=embedding_dim, device=device),
+ fl.Chain(
+ fl.GetArg(4),
+ fl.Flatten(0, 1),
+ CBR(128, embedding_dim, device=device), # output1
+ Unflatten(0, (-1, 5)),
+ ),
+ ),
+ MCRM(embedding_dim, 128, device=device), # dec_blk1
+ fl.Flatten(0, 1),
+ CBR(embedding_dim, device=device), # conv1
+ Unflatten(0, (-1, 5)),
+ )
+
+
+class RearrangeMultiView(fl.Chain):
+ """
+ Inputs:
+ multi_view (b, 5, E, H, W)
+
+ Outputs:
+ single_view (b, E, H*2, W*2)
+
+    Fuse a multi-view tensor into a single-view tensor, using convolutions
+ See also the reverse Module [`SplitMultiView`][refiners.foundationals.swin.mvanet.SplitMultiView]
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Sum(
+ fl.Chain( # local features
+ fl.Slicing(dim=1, end=4),
+ PatchMerge(),
+ ),
+ fl.Chain( # global feature
+ fl.Slicing(dim=1, start=4),
+ fl.Squeeze(1),
+ Interpolate((256, 256)),
+ ),
+ ),
+ fl.Chain( # conv head
+ CBR(embedding_dim, 384, device=device),
+ CBR(384, device=device),
+ fl.Conv2d(384, embedding_dim, kernel_size=3, padding=1, device=device),
+ ),
+ )
+
+
+class ComputeShallow(fl.Passthrough):
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device),
+ fl.SetContext("mvanet", "shallow"),
+ )
+
+
+class MVANet(fl.Chain):
+ """Multi-view Aggregation Network for Dichotomous Image Segmentation
+
+ See [[arXiv:2404.07445] Multi-view Aggregation Network for Dichotomous Image Segmentation](https://arxiv.org/abs/2404.07445) for more details.
+
+ Args:
+ embedding_dim (int): embedding dimension
+        n_logits (int): the number of output logits (defaults to 1);
+            1 logit is used for alpha matting / foreground-background segmentation / salient object detection (SOD)
+ depths (list[int]): see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer]
+ num_heads (list[int]): see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer]
+        window_size (int): defaults to 12, see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer]
+ device (Device | None): the device to use
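+
+    Example (a minimal inference sketch; the weights path is illustrative):
+
+        model = MVANet(device=device).eval()
+        model.load_from_safetensors("mvanet.safetensors")
+        prediction = model(image).sigmoid()  # image: (1, 3, 1024, 1024), ImageNet-normalized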
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int = 128,
+ n_logits: int = 1,
+ depths: list[int] | None = None,
+ num_heads: list[int] | None = None,
+ window_size: int = 12,
+ device: Device | None = None,
+ ):
+ if depths is None:
+ depths = [2, 2, 18, 2]
+ if num_heads is None:
+ num_heads = [4, 8, 16, 32]
+
+ super().__init__(
+ ComputeShallow(embedding_dim=embedding_dim, device=device),
+ SplitMultiView(),
+ fl.Flatten(0, 1),
+ SwinTransformer(
+ embedding_dim=embedding_dim,
+ depths=depths,
+ num_heads=num_heads,
+ window_size=window_size,
+ device=device,
+ ),
+ fl.Distribute(*(Unflatten(0, (-1, 5)) for _ in range(5))),
+ Pyramid(embedding_dim=embedding_dim, device=device),
+ RearrangeMultiView(embedding_dim=embedding_dim, device=device),
+ ShallowUpscaler(embedding_dim, device=device),
+ fl.Conv2d(embedding_dim, n_logits, kernel_size=3, padding=1, device=device),
+ )
+
+ def init_context(self) -> Contexts:
+ return {"mvanet": {"shallow": None}}
diff --git a/src/refiners/foundationals/swin/mvanet/utils.py b/src/refiners/foundationals/swin/mvanet/utils.py
new file mode 100644
index 0000000..363b55a
--- /dev/null
+++ b/src/refiners/foundationals/swin/mvanet/utils.py
@@ -0,0 +1,173 @@
+import torch
+from torch import Size, Tensor
+from torch.nn.functional import (
+ adaptive_avg_pool2d,
+ interpolate, # type: ignore
+)
+
+import refiners.fluxion.layers as fl
+
+
+class Unflatten(fl.Module):
+ def __init__(self, dim: int, sizes: tuple[int, ...]) -> None:
+ super().__init__()
+ self.dim = dim
+ self.sizes = Size(sizes)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.unflatten(input=x, dim=self.dim, sizes=self.sizes)
+
+
+class Interpolate(fl.Module):
+ def __init__(self, size: tuple[int, ...], mode: str = "bilinear"):
+ super().__init__()
+ self.size = Size(size)
+ self.mode = mode
+
+ def forward(self, x: Tensor) -> Tensor:
+ return interpolate(x, size=self.size, mode=self.mode) # type: ignore
+
+
+class Rescale(fl.Module):
+ def __init__(self, scale_factor: float, mode: str = "nearest"):
+ super().__init__()
+ self.scale_factor = scale_factor
+ self.mode = mode
+
+ def forward(self, x: Tensor) -> Tensor:
+ return interpolate(x, scale_factor=self.scale_factor, mode=self.mode) # type: ignore
+
+
+class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule):
+ def __init__(self, num_features: int, device: torch.device | None = None):
+ super().__init__(num_features=num_features, device=device) # type: ignore
+
+
+class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation):
+ def __init__(self, device: torch.device | None = None):
+ super().__init__(device=device) # type: ignore
+
+
+class PatchSplit(fl.Chain):
+ """(B, N, H, W) -> B, 4, N, H/2, W/2"""
+
+ def __init__(self):
+ super().__init__(
+ Unflatten(-2, (2, -1)),
+ Unflatten(-1, (2, -1)),
+ fl.Permute(0, 2, 4, 1, 3, 5),
+ fl.Flatten(1, 2),
+ )
+
+
+class PatchMerge(fl.Chain):
+ """B, 4, N, H, W -> (B, N, 2*H, 2*W)"""
+
+ def __init__(self):
+ super().__init__(
+ Unflatten(1, (2, 2)),
+ fl.Permute(0, 3, 1, 4, 2, 5),
+ fl.Flatten(-2, -1),
+ fl.Flatten(-3, -2),
+ )
+
+
+class FeedForward(fl.Residual):
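+    # Residual two-layer MLP with a 2x hidden expansion, shared by MCLM and MCRM.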
+ def __init__(self, emb_dim: int, device: torch.device | None = None) -> None:
+ super().__init__(
+ fl.Linear(in_features=emb_dim, out_features=2 * emb_dim, device=device),
+ fl.ReLU(),
+ fl.Linear(in_features=2 * emb_dim, out_features=emb_dim, device=device),
+ )
+
+
+class _GetArgs(fl.Parallel):
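+    # From args (q, kv), select patch n and return (q[n], kv[n], kv[n]) as the
+    # query, key and value of one MultiheadAttention in PatchwiseCrossAttention.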
+ def __init__(self, n: int):
+ super().__init__(
+ fl.Chain(
+ fl.GetArg(0),
+ fl.Slicing(dim=0, start=n, end=n + 1),
+ fl.Squeeze(0),
+ ),
+ fl.Chain(
+ fl.GetArg(1),
+ fl.Slicing(dim=0, start=n, end=n + 1),
+ fl.Squeeze(0),
+ ),
+ fl.Chain(
+ fl.GetArg(1),
+ fl.Slicing(dim=0, start=n, end=n + 1),
+ fl.Squeeze(0),
+ ),
+ )
+
+
+class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule):
+ def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None):
+ super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore
+
+ @property
+ def weight(self) -> Tensor: # type: ignore
+ return self.in_proj_weight
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # type: ignore
+ return super().forward(q, k, v)[0]
+
+
+class PatchwiseCrossAttention(fl.Chain):
+ # Input is 2 tensors of sizes (4, HW, B, C) and (4, HW', B, C),
+ # output is size (4, HW, B, C).
+ def __init__(
+ self,
+ d_model: int,
+ num_heads: int,
+ device: torch.device | None = None,
+ ):
+ super().__init__(
+ fl.Concatenate(
+ fl.Chain(
+ _GetArgs(0),
+ MultiheadAttention(d_model, num_heads, device=device),
+ ),
+ fl.Chain(
+ _GetArgs(1),
+ MultiheadAttention(d_model, num_heads, device=device),
+ ),
+ fl.Chain(
+ _GetArgs(2),
+ MultiheadAttention(d_model, num_heads, device=device),
+ ),
+ fl.Chain(
+ _GetArgs(3),
+ MultiheadAttention(d_model, num_heads, device=device),
+ ),
+ ),
+ Unflatten(0, (4, -1)),
+ )
+
+
+class Pool(fl.Module):
+ def __init__(self, ratio: int) -> None:
+ super().__init__()
+ self.ratio = ratio
+
+ def forward(self, x: Tensor) -> Tensor:
+ b, _, h, w = x.shape
+ assert h % self.ratio == 0 and w % self.ratio == 0
+ r = adaptive_avg_pool2d(x, (h // self.ratio, w // self.ratio))
+ return torch.unflatten(r, 0, (b, -1))
+
+
+class MultiPool(fl.Concatenate):
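+    # Average-pool a (B, C, H, W) map at each ratio, then flatten each result
+    # into per-pixel tokens and concatenate them along dim 1.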
+ def __init__(self, pool_ratios: list[int]) -> None:
+ super().__init__(
+ *(
+ fl.Chain(
+ Pool(pool_ratio),
+ fl.Flatten(-2, -1),
+ fl.Permute(0, 3, 1, 2),
+ )
+ for pool_ratio in pool_ratios
+ ),
+ dim=1,
+ )
diff --git a/src/refiners/foundationals/swin/swin_transformer.py b/src/refiners/foundationals/swin/swin_transformer.py
new file mode 100644
index 0000000..f1291fe
--- /dev/null
+++ b/src/refiners/foundationals/swin/swin_transformer.py
@@ -0,0 +1,391 @@
+# Swin Transformer (arXiv:2103.14030)
+#
+# Specific to MVANet, only supports square inputs.
+# Originally adapted from the version in MVANet and InSPyReNet (https://github.com/plemeri/InSPyReNet)
+# Original implementation by Microsoft at https://github.com/microsoft/Swin-Transformer
+
+import functools
+from math import isqrt
+
+import torch
+from torch import Tensor, device as Device
+
+import refiners.fluxion.layers as fl
+from refiners.fluxion.context import Contexts
+
+
+def to_windows(x: Tensor, window_size: int) -> Tensor:
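+    # (B, H, W, C) -> (B, num_windows, window_size**2, C)
+    # e.g. (1, 12, 12, C) with window_size=4 -> (1, 9, 16, C): nine 4x4 windows.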
+ B, H, W, C = x.shape
+ assert W == H and H % window_size == 0
+ x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
+ return x.permute(0, 1, 3, 2, 4, 5).reshape(B, -1, window_size * window_size, C)
+
+
+class ToWindows(fl.Module):
+ def __init__(self, window_size: int):
+ super().__init__()
+ self.window_size = window_size
+
+ def forward(self, x: Tensor) -> Tensor:
+ return to_windows(x, self.window_size)
+
+
+class FromWindows(fl.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ B, num_windows, window_size_2, C = x.shape
+ window_size = isqrt(window_size_2)
+ H = isqrt(num_windows * window_size_2)
+ x = x.reshape(B, H // window_size, H // window_size, window_size, window_size, C)
+ return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, H, C)
+
+
+@functools.cache
+def get_attn_mask(H: int, window_size: int, device: Device | None = None) -> Tensor:
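+    # Attention mask for shifted windows (SW-MSA): after the cyclic shift by
+    # window_size // 2, tokens coming from different original regions share a
+    # window and must be masked out of each other's attention (-100 before softmax).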
+ assert H % window_size == 0
+ shift_size = window_size // 2
+ img_mask = torch.zeros((1, H, H, 1), device=device)
+ h_slices = (
+ slice(0, -window_size),
+ slice(-window_size, -shift_size),
+ slice(-shift_size, None),
+ )
+ w_slices = (
+ slice(0, -window_size),
+ slice(-window_size, -shift_size),
+ slice(-shift_size, None),
+ )
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+    mask_windows = to_windows(img_mask, window_size).squeeze()  # (nW, window_size * window_size) after squeeze
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask.masked_fill_(attn_mask != 0, -100.0).masked_fill_(attn_mask == 0, 0.0)
+ return attn_mask
+
+
+class Pad(fl.Module):
+ def __init__(self, step: int):
+ super().__init__()
+ self.step = step
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, H, W, C = x.shape
+ assert W == H
+ if H % self.step == 0:
+ return x
+ p = self.step * ((H + self.step - 1) // self.step)
+ padded = torch.zeros(B, p, p, C, device=x.device, dtype=x.dtype)
+ padded[:, :H, :H, :] = x
+ return padded
+
+
+class StatefulPad(fl.Chain):
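+    # Pad H and W to a multiple of `step`, recording the original size in the
+    # context so that StatefulUnpad can crop it back after windowed attention.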
+ def __init__(self, context: str, key: str, step: int) -> None:
+ super().__init__(
+ fl.SetContext(context=context, key=key, callback=self._push),
+ Pad(step=step),
+ )
+
+ def _push(self, sizes: list[int], x: Tensor) -> None:
+ sizes.append(x.size(1))
+
+
+class StatefulUnpad(fl.Chain):
+ def __init__(self, context: str, key: str) -> None:
+ super().__init__(
+ fl.Parallel(
+ fl.Identity(),
+ fl.UseContext(context=context, key=key).compose(lambda x: x.pop()),
+ ),
+ fl.Lambda(self._unpad),
+ )
+
+ @staticmethod
+ def _unpad(x: Tensor, size: int) -> Tensor:
+ return x[:, :size, :size, :]
+
+
+class SquareUnflatten(fl.Module):
+ # ..., L^2, ... -> ..., L, L, ...
+
+ def __init__(self, dim: int = 0) -> None:
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x: Tensor) -> Tensor:
+ d = isqrt(x.shape[self.dim])
+ return torch.unflatten(x, self.dim, (d, d))
+
+
+class WindowUnflatten(fl.Module):
+ # ..., H, ... -> ..., H // ws, ws, ...
+
+ def __init__(self, window_size: int, dim: int = 0) -> None:
+ super().__init__()
+ self.window_size = window_size
+ self.dim = dim
+
+ def forward(self, x: Tensor) -> Tensor:
+ assert x.shape[self.dim] % self.window_size == 0
+ H = x.shape[self.dim]
+ return torch.unflatten(x, self.dim, (H // self.window_size, self.window_size))
+
+
+class Roll(fl.Module):
+ def __init__(self, *shifts: tuple[int, int]):
+ super().__init__()
+ self.shifts = shifts
+ self._dims = tuple(s[0] for s in shifts)
+ self._shifts = tuple(s[1] for s in shifts)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.roll(x, self._shifts, self._dims)
+
+
+class RelativePositionBias(fl.Module):
+ relative_position_index: Tensor
+
+ def __init__(self, window_size: int, num_heads: int, device: Device | None = None):
+ super().__init__()
+ self.relative_position_bias_table = torch.nn.Parameter(
+ torch.empty(
+ (2 * window_size - 1) * (2 * window_size - 1),
+ num_heads,
+ device=device,
+ )
+ )
+ relative_position_index = torch.empty(
+ window_size**2,
+ window_size**2,
+ device=device,
+ dtype=torch.int64,
+ )
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ def forward(self) -> Tensor:
+ # Yes, this is a (trainable) constant.
+ return self.relative_position_bias_table[self.relative_position_index].permute(2, 0, 1).unsqueeze(0)
+
+
+class WindowSDPA(fl.Module):
+ def __init__(
+ self,
+ dim: int,
+ window_size: int,
+ num_heads: int,
+ shift: bool = False,
+ device: Device | None = None,
+ ):
+ super().__init__()
+ self.window_size = window_size
+ self.num_heads = num_heads
+ self.shift = shift
+ self.rpb = RelativePositionBias(window_size, num_heads, device=device)
+
+ def forward(self, x: Tensor):
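+        # x contains the concatenated q, k, v projections: (B, num_windows, N, 3 * C)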
+ B, num_windows, N, _C = x.shape
+ assert _C % (3 * self.num_heads) == 0
+ C = _C // 3
+ x = torch.reshape(x, (B * num_windows, N, 3, self.num_heads, C // self.num_heads))
+ q, k, v = x.permute(2, 0, 3, 1, 4)
+
+ attn_mask = self.rpb()
+ if self.shift:
+ mask = get_attn_mask(isqrt(num_windows * (self.window_size**2)), self.window_size, x.device)
+ mask = mask.reshape(1, num_windows, 1, N, N)
+ mask = mask.expand(B, -1, self.num_heads, -1, -1)
+ attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N)
+
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
+ x = x.transpose(1, 2).reshape(B, num_windows, N, C)
+ return x
+
+
+class WindowAttention(fl.Chain):
+ """
+    Window-based Multi-head Self-Attention (W-MSA), optionally shifted (SW-MSA).
+
+ It has a trainable relative position bias (RelativePositionBias).
+
+ The input projection is stored as a single Linear for q, k and v.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ window_size: int,
+ num_heads: int,
+ shift: bool = False,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Linear(dim, dim * 3, bias=True, device=device),
+ WindowSDPA(dim, window_size, num_heads, shift, device=device),
+ fl.Linear(dim, dim, device=device),
+ )
+
+
+class SwinTransformerBlock(fl.Chain):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ window_size: int = 7,
+ shift_size: int = 0,
+ mlp_ratio: float = 4.0,
+ device: Device | None = None,
+ ):
+ assert 0 <= shift_size < window_size, "shift_size must in [0, window_size["
+
+ super().__init__(
+ fl.Residual(
+ fl.LayerNorm(dim, device=device),
+ SquareUnflatten(1),
+ StatefulPad(context="padding", key="sizes", step=window_size),
+ Roll((1, -shift_size), (2, -shift_size)),
+ ToWindows(window_size),
+ WindowAttention(
+ dim,
+ window_size=window_size,
+ num_heads=num_heads,
+ shift=shift_size > 0,
+ device=device,
+ ),
+ FromWindows(),
+ Roll((1, shift_size), (2, shift_size)),
+ StatefulUnpad(context="padding", key="sizes"),
+ fl.Flatten(1, 2),
+ ),
+ fl.Residual(
+ fl.LayerNorm(dim, device=device),
+ fl.Linear(dim, int(dim * mlp_ratio), device=device),
+ fl.GeLU(),
+ fl.Linear(int(dim * mlp_ratio), dim, device=device),
+ ),
+ )
+
+ def init_context(self) -> Contexts:
+ return {"padding": {"sizes": []}}
+
+
+class PatchMerging(fl.Chain):
+ def __init__(self, dim: int, device: Device | None = None):
+ super().__init__(
+ SquareUnflatten(1),
+ Pad(2),
+ WindowUnflatten(2, 2),
+ WindowUnflatten(2, 1),
+ fl.Permute(0, 1, 3, 4, 2, 5),
+ fl.Flatten(3),
+ fl.Flatten(1, 2),
+ fl.LayerNorm(4 * dim, device=device),
+ fl.Linear(4 * dim, 2 * dim, bias=False, device=device),
+ )
+
+
+class BasicLayer(fl.Chain):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ num_heads: int,
+ window_size: int = 7,
+ mlp_ratio: float = 4.0,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ device=device,
+ )
+ for i in range(depth)
+ )
+
+
+class PatchEmbedding(fl.Chain):
+ def __init__(
+ self,
+ patch_size: tuple[int, int] = (4, 4),
+ in_chans: int = 3,
+ embedding_dim: int = 96,
+ device: Device | None = None,
+ ):
+ super().__init__(
+ fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device),
+ fl.Flatten(2),
+ fl.Transpose(1, 2),
+ fl.LayerNorm(embedding_dim, device=device),
+ )
+
+
+class SwinTransformer(fl.Chain):
+ """Swin Transformer (arXiv:2103.14030)
+
+ Currently specific to MVANet, only supports square inputs.
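+
+    Outputs a tuple of 5 feature maps ordered from the deepest stage to the
+    patch embedding, as consumed by MVANet's Pyramid.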
+ """
+
+ def __init__(
+ self,
+ patch_size: tuple[int, int] = (4, 4),
+ in_chans: int = 3,
+ embedding_dim: int = 96,
+ depths: list[int] | None = None,
+ num_heads: list[int] | None = None,
+ window_size: int = 7, # image size is 32 * this
+ mlp_ratio: float = 4.0,
+ device: Device | None = None,
+ ):
+ if depths is None:
+ depths = [2, 2, 6, 2]
+
+ if num_heads is None:
+ num_heads = [3, 6, 12, 24]
+
+ self.num_layers = len(depths)
+ assert len(num_heads) == self.num_layers
+
+ super().__init__(
+ PatchEmbedding(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embedding_dim=embedding_dim,
+ device=device,
+ ),
+ fl.Passthrough(
+ fl.Transpose(1, 2),
+ SquareUnflatten(2),
+ fl.SetContext("swin", "outputs", callback=lambda t, x: t.append(x)),
+ ),
+ *(
+ fl.Chain(
+ BasicLayer(
+ dim=int(embedding_dim * 2**i),
+ depth=depths[i],
+ num_heads=num_heads[i],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ device=device,
+ ),
+ fl.Passthrough(
+ fl.LayerNorm(int(embedding_dim * 2**i), device=device),
+ fl.Transpose(1, 2),
+ SquareUnflatten(2),
+ fl.SetContext("swin", "outputs", callback=lambda t, x: t.insert(0, x)),
+ ),
+ PatchMerging(dim=int(embedding_dim * 2**i), device=device)
+ if i < self.num_layers - 1
+ else fl.UseContext("swin", "outputs").compose(lambda t: tuple(t)),
+ )
+ for i in range(self.num_layers)
+ ),
+ )
+
+ def init_context(self) -> Contexts:
+ return {"swin": {"outputs": []}}
diff --git a/tests/e2e/test_mvanet.py b/tests/e2e/test_mvanet.py
new file mode 100644
index 0000000..08a47ad
--- /dev/null
+++ b/tests/e2e/test_mvanet.py
@@ -0,0 +1,59 @@
+from pathlib import Path
+from warnings import warn
+
+import pytest
+import torch
+from PIL import Image
+from tests.utils import ensure_similar_images
+
+from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image
+from refiners.foundationals.swin.mvanet import MVANet
+
+
+def _img_open(path: Path) -> Image.Image:
+ return Image.open(path) # type: ignore
+
+
+@pytest.fixture(scope="module")
+def ref_path(test_e2e_path: Path) -> Path:
+ return test_e2e_path / "test_mvanet_ref"
+
+
+@pytest.fixture(scope="module")
+def ref_cactus(ref_path: Path) -> Image.Image:
+ return _img_open(ref_path / "cactus.png").convert("RGB")
+
+
+@pytest.fixture
+def expected_cactus_mask(ref_path: Path) -> Image.Image:
+ return _img_open(ref_path / "expected_cactus_mask.png")
+
+
+@pytest.fixture(scope="module")
+def mvanet_weights(test_weights_path: Path) -> Path:
+ weights = test_weights_path / "mvanet" / "mvanet.safetensors"
+ if not weights.is_file():
+ warn(f"could not find weights at {test_weights_path}, skipping")
+ pytest.skip(allow_module_level=True)
+ return weights
+
+
+@pytest.fixture
+def mvanet_model(mvanet_weights: Path, test_device: torch.device) -> MVANet:
+ model = MVANet(device=test_device).eval() # .eval() is important!
+ model.load_from_safetensors(mvanet_weights)
+ return model
+
+
+@no_grad()
+def test_mvanet(
+ mvanet_model: MVANet,
+ ref_cactus: Image.Image,
+ expected_cactus_mask: Image.Image,
+ test_device: torch.device,
+):
+ in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
+ in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
+ prediction: torch.Tensor = mvanet_model(in_t.to(test_device)).sigmoid()
+ cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
+ ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))
diff --git a/tests/e2e/test_mvanet_ref/README.md b/tests/e2e/test_mvanet_ref/README.md
new file mode 100644
index 0000000..687f9f8
--- /dev/null
+++ b/tests/e2e/test_mvanet_ref/README.md
@@ -0,0 +1,3 @@
+`cactus.png` is cropped from this image: https://www.freepik.com/free-photo/laptop-notebook-pen-coffee-cup-plants-wooden-desk_269339828.htm
+
+`expected_cactus_mask.png` has been generated using the [official MVANet codebase](https://github.com/qianyu-dlut/MVANet) and weights.
diff --git a/tests/e2e/test_mvanet_ref/cactus.png b/tests/e2e/test_mvanet_ref/cactus.png
new file mode 100644
index 0000000..b917acb
Binary files /dev/null and b/tests/e2e/test_mvanet_ref/cactus.png differ
diff --git a/tests/e2e/test_mvanet_ref/expected_cactus_mask.png b/tests/e2e/test_mvanet_ref/expected_cactus_mask.png
new file mode 100644
index 0000000..ac7fd77
Binary files /dev/null and b/tests/e2e/test_mvanet_ref/expected_cactus_mask.png differ
diff --git a/typings/gdown/__init__.pyi b/typings/gdown/__init__.pyi
new file mode 100644
index 0000000..01d0b52
--- /dev/null
+++ b/typings/gdown/__init__.pyi
@@ -0,0 +1 @@
+def download(id: str, output: str, quiet: bool = False) -> str: ...