diff --git a/docs/reference/SUMMARY.md b/docs/reference/SUMMARY.md index 8338a24..76475d6 100644 --- a/docs/reference/SUMMARY.md +++ b/docs/reference/SUMMARY.md @@ -9,4 +9,4 @@ * [ DINOv2](foundationals/dinov2.md) * [ Latent Diffusion](foundationals/latent_diffusion.md) * [ Segment Anything](foundationals/segment_anything.md) - + * [ Swin Transformers](foundationals/swin.md) diff --git a/docs/reference/foundationals/swin.md b/docs/reference/foundationals/swin.md new file mode 100644 index 0000000..02c39b3 --- /dev/null +++ b/docs/reference/foundationals/swin.md @@ -0,0 +1,2 @@ +::: refiners.foundationals.swin.swin_transformer +::: refiners.foundationals.swin.mvanet diff --git a/pyproject.toml b/pyproject.toml index c3962e8..ee6909f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ conversion = [ "segment-anything-py>=1.0", "requests>=2.26.0", "tqdm>=4.62.3", + "gdown>=5.2.0", ] doc = [ # required by mkdocs to format the signatures diff --git a/requirements.lock b/requirements.lock index a56a5a5..f9ec260 100644 --- a/requirements.lock +++ b/requirements.lock @@ -31,6 +31,8 @@ babel==2.15.0 # via mkdocs-material backports-strenum==1.3.1 # via griffe +beautifulsoup4==4.12.3 + # via gdown bitsandbytes==0.43.3 # via refiners black==24.4.2 @@ -70,6 +72,7 @@ docker-pycreds==0.4.0 filelock==3.15.4 # via datasets # via diffusers + # via gdown # via huggingface-hub # via torch # via transformers @@ -85,6 +88,8 @@ fsspec==2024.5.0 # via torch future==1.0.0 # via neptune +gdown==5.2.0 + # via refiners ghp-import==2.1.0 # via mkdocs gitdb==4.0.11 @@ -274,6 +279,8 @@ pyjwt==2.9.0 pymdown-extensions==10.9 # via mkdocs-material # via mkdocstrings +pysocks==1.7.1 + # via requests python-dateutil==2.9.0.post0 # via arrow # via botocore @@ -311,6 +318,7 @@ requests==2.32.3 # via bravado-core # via datasets # via diffusers + # via gdown # via huggingface-hub # via mkdocs-material # via neptune @@ -356,6 +364,8 @@ six==1.16.0 # via rfc3339-validator smmap==5.0.1 # via gitdb +soupsieve==2.6 + # via beautifulsoup4 swagger-spec-validator==3.0.4 # via bravado-core # via neptune @@ -383,6 +393,7 @@ torchvision==0.19.0 # via timm tqdm==4.66.4 # via datasets + # via gdown # via huggingface-hub # via refiners # via transformers diff --git a/scripts/conversion/convert_mvanet.py b/scripts/conversion/convert_mvanet.py new file mode 100644 index 0000000..e5b30e7 --- /dev/null +++ b/scripts/conversion/convert_mvanet.py @@ -0,0 +1,40 @@ +import argparse +from pathlib import Path + +from refiners.fluxion.utils import load_tensors, save_to_safetensors +from refiners.foundationals.swin.mvanet.converter import convert_weights + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--from", + type=str, + required=True, + dest="source_path", + help="A MVANet checkpoint. One can be found at https://github.com/qianyu-dlut/MVANet", + ) + parser.add_argument( + "--to", + type=str, + dest="output_path", + default=None, + help=( + "Path to save the converted model. If not specified, the output path will be the source path with the" + " extension changed to .safetensors." 
+ ), + ) + parser.add_argument("--half", action="store_true", dest="half") + args = parser.parse_args() + + src_weights = load_tensors(args.source_path) + weights = convert_weights(src_weights) + if args.half: + weights = {key: value.half() for key, value in weights.items()} + if args.output_path is None: + args.output_path = f"{Path(args.source_path).stem}.safetensors" + save_to_safetensors(path=args.output_path, tensors=weights) + + +if __name__ == "__main__": + main() diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py index c20447c..6f27a67 100644 --- a/scripts/prepare_test_weights.py +++ b/scripts/prepare_test_weights.py @@ -11,6 +11,7 @@ import subprocess import sys from urllib.parse import urlparse +import gdown import requests from tqdm import tqdm @@ -446,6 +447,25 @@ def download_ic_light(): ) +def download_mvanet(): + fn = "Model_80.pth" + dest_folder = os.path.join(test_weights_dir, "mvanet") + dest_filename = os.path.join(dest_folder, fn) + + if os.environ.get("DRY_RUN") == "1": + return + + if os.path.exists(dest_filename): + print(f"✖️ ️ Skipping previously downloaded mvanet/{fn}") + else: + os.makedirs(dest_folder, exist_ok=True) + print(f"🔽 Downloading mvanet/{fn} => '{rel(dest_filename)}'", end="\n") + gdown.download(id="1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv", output=dest_filename, quiet=True) + print(f"{previous_line}✅ Downloaded mvanet/{fn} => '{rel(dest_filename)}' ") + + check_hash(dest_filename, "b915d492") + + def printg(msg: str): """print in green color""" print("\033[92m" + msg + "\033[0m") @@ -808,6 +828,16 @@ def convert_ic_light(): ) +def convert_mvanet(): + run_conversion_script( + "convert_mvanet.py", + "tests/weights/mvanet/Model_80.pth", + "tests/weights/mvanet/mvanet.safetensors", + half=True, + expected_hash="bf9ae4cb", + ) + + def download_all(): print(f"\nAll weights will be downloaded to {test_weights_dir}\n") download_sd15("runwayml/stable-diffusion-v1-5") @@ -830,6 +860,7 @@ def download_all(): download_sdxl_lightning_base() download_sdxl_lightning_lora() download_ic_light() + download_mvanet() def convert_all(): @@ -850,6 +881,7 @@ def convert_all(): convert_lcm_base() convert_sdxl_lightning_base() convert_ic_light() + convert_mvanet() def main(): diff --git a/src/refiners/foundationals/swin/__init__.py b/src/refiners/foundationals/swin/__init__.py new file mode 100644 index 0000000..51cfb22 --- /dev/null +++ b/src/refiners/foundationals/swin/__init__.py @@ -0,0 +1,3 @@ +from .swin_transformer import SwinTransformer + +__all__ = ["SwinTransformer"] diff --git a/src/refiners/foundationals/swin/mvanet/__init__.py b/src/refiners/foundationals/swin/mvanet/__init__.py new file mode 100644 index 0000000..5c54590 --- /dev/null +++ b/src/refiners/foundationals/swin/mvanet/__init__.py @@ -0,0 +1,3 @@ +from .mvanet import MVANet + +__all__ = ["MVANet"] diff --git a/src/refiners/foundationals/swin/mvanet/converter.py b/src/refiners/foundationals/swin/mvanet/converter.py new file mode 100644 index 0000000..daacd69 --- /dev/null +++ b/src/refiners/foundationals/swin/mvanet/converter.py @@ -0,0 +1,138 @@ +import re + +from torch import Tensor + + +def convert_weights(official_state_dict: dict[str, Tensor]) -> dict[str, Tensor]: + rm_list = [ + # Official weights contains useless keys + # See https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425 + r"multifieldcrossatt.linear[56]", + r"multifieldcrossatt.attention.5", + r"dec_blk\d+\.linear[12]", + r"dec_blk[1234]\.attention\.[4567]", + # We don't need the sideout weights + 
r"sideout\d+", + ] + state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)} + + keys_map: dict[str, str] = {} + for k in state_dict.keys(): + v: str = k + + def rpfx(s: str, src: str, dst: str) -> str: + if not s.startswith(src): + return s + return s.replace(src, dst, 1) + + # Swin Transformer backbone + + v = rpfx(v, "backbone.patch_embed.proj.", "SwinTransformer.PatchEmbedding.Conv2d.") + v = rpfx(v, "backbone.patch_embed.norm.", "SwinTransformer.PatchEmbedding.LayerNorm.") + + if m := re.match(r"backbone\.layers\.(\d+)\.downsample\.(.*)", v): + s = m.group(2).replace("reduction.", "Linear.").replace("norm.", "LayerNorm.") + v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.PatchMerging.{s}" + + if m := re.match(r"backbone\.layers\.(\d+)\.blocks\.(\d+)\.(.*)", v): + s = m.group(3) + s = s.replace("norm1.", "Residual_1.LayerNorm.") + s = s.replace("norm2.", "Residual_2.LayerNorm.") + + s = s.replace("attn.qkv.", "Residual_1.WindowAttention.Linear_1.") + s = s.replace("attn.proj.", "Residual_1.WindowAttention.Linear_2.") + s = s.replace("attn.relative_position", "Residual_1.WindowAttention.WindowSDPA.rpb.relative_position") + + s = s.replace("mlp.fc", "Residual_2.Linear_") + v = ".".join( + [ + f"SwinTransformer.Chain_{int(m.group(1)) + 1}", + f"BasicLayer.SwinTransformerBlock_{int(m.group(2)) + 1}", + s, + ] + ) + + if m := re.match(r"backbone\.norm(\d+)\.(.*)", v): + v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.Passthrough.LayerNorm.{m.group(2)}" + + # MVANet + + def mclm(s: str, pfx_src: str, pfx_dst: str) -> str: + pca = f"{pfx_dst}Residual.PatchwiseCrossAttention" + s = rpfx(s, f"{pfx_src}linear1.", f"{pfx_dst}FeedForward_1.Linear_1.") + s = rpfx(s, f"{pfx_src}linear2.", f"{pfx_dst}FeedForward_1.Linear_2.") + s = rpfx(s, f"{pfx_src}linear3.", f"{pfx_dst}FeedForward_2.Linear_1.") + s = rpfx(s, f"{pfx_src}linear4.", f"{pfx_dst}FeedForward_2.Linear_2.") + s = rpfx(s, f"{pfx_src}norm1.", f"{pfx_dst}LayerNorm_1.") + s = rpfx(s, f"{pfx_src}norm2.", f"{pfx_dst}LayerNorm_2.") + s = rpfx(s, f"{pfx_src}attention.0.", f"{pfx_dst}GlobalAttention.Sum.Chain.MultiheadAttention.") + s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.") + s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.") + s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.") + s = rpfx(s, f"{pfx_src}attention.4.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.") + return s + + def mcrm(s: str, pfx_src: str, pfx_dst: str) -> str: + # Note: there are no linear{1,2}, see https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425 + tca = f"{pfx_dst}Parallel_3.TiledCrossAttention" + pca = f"{tca}.Sum.Chain_2.PatchwiseCrossAttention" + s = rpfx(s, f"{pfx_src}linear3.", f"{tca}.FeedForward.Linear_1.") + s = rpfx(s, f"{pfx_src}linear4.", f"{tca}.FeedForward.Linear_2.") + s = rpfx(s, f"{pfx_src}norm1.", f"{tca}.LayerNorm_1.") + s = rpfx(s, f"{pfx_src}norm2.", f"{tca}.LayerNorm_2.") + s = rpfx(s, f"{pfx_src}attention.0.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.") + s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.") + s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.") + s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.") + s = rpfx(s, f"{pfx_src}sal_conv.", f"{pfx_dst}Parallel_2.Multiply.Chain.Conv2d.") + return s + + def cbr(s: str, pfx_src: str, pfx_dst: 
str, shift: int = 0) -> str: + s = rpfx(s, f"{pfx_src}{shift}.", f"{pfx_dst}Conv2d.") + s = rpfx(s, f"{pfx_src}{shift + 1}.", f"{pfx_dst}BatchNorm2d.") + s = rpfx(s, f"{pfx_src}{shift + 2}.", f"{pfx_dst}PReLU.") + return s + + def cbg(s: str, pfx_src: str, pfx_dst: str) -> str: + s = rpfx(s, f"{pfx_src}0.", f"{pfx_dst}Conv2d.") + s = rpfx(s, f"{pfx_src}1.", f"{pfx_dst}BatchNorm2d.") + return s + + v = rpfx(v, "shallow.0.", "ComputeShallow.Conv2d.") + + v = cbr(v, "output1.", "Pyramid.Sum.Chain.CBR.") + v = cbr(v, "output2.", "Pyramid.Sum.PyramidL2.Sum.Chain.CBR.") + v = cbr(v, "output3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.Chain.CBR.") + v = cbr(v, "output4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.Chain.CBR.") + v = cbr(v, "output5.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.CBR.") + + v = cbr(v, "conv1.", "Pyramid.CBR.") + v = cbr(v, "conv2.", "Pyramid.Sum.PyramidL2.CBR.") + v = cbr(v, "conv3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.CBR.") + v = cbr(v, "conv4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.CBR.") + + v = mclm(v, "multifieldcrossatt.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.MCLM.") + + v = mcrm(v, "dec_blk1.", "Pyramid.MCRM.") + v = mcrm(v, "dec_blk2.", "Pyramid.Sum.PyramidL2.MCRM.") + v = mcrm(v, "dec_blk3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.MCRM.") + v = mcrm(v, "dec_blk4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.MCRM.") + + v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_1.") + v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_2.", shift=3) + + v = rpfx(v, "insmask_head.6.", "RearrangeMultiView.Chain.Conv2d.") + + v = cbg(v, "upsample1.", "ShallowUpscaler.Sum_2.Chain_1.CBG.") + v = cbg(v, "upsample2.", "ShallowUpscaler.CBG.") + + v = rpfx(v, "output.0.", "Conv2d.") + + if v != k: + keys_map[k] = v + + for key, new_key in keys_map.items(): + state_dict[new_key] = state_dict[key] + state_dict.pop(key) + + return state_dict diff --git a/src/refiners/foundationals/swin/mvanet/mclm.py b/src/refiners/foundationals/swin/mvanet/mclm.py new file mode 100644 index 0000000..041308a --- /dev/null +++ b/src/refiners/foundationals/swin/mvanet/mclm.py @@ -0,0 +1,211 @@ +# Multi-View Complementary Localization + +import math + +import torch +from torch import Tensor, device as Device + +import refiners.fluxion.layers as fl +from refiners.fluxion.context import Contexts + +from .utils import FeedForward, MultiheadAttention, MultiPool, PatchMerge, PatchwiseCrossAttention, Unflatten + + +class PerPixel(fl.Chain): + """(B, C, H, W) -> H*W, B, C""" + + def __init__(self): + super().__init__( + fl.Permute(2, 3, 0, 1), + fl.Flatten(0, 1), + ) + + +class PositionEmbeddingSine(fl.Module): + """ + Non-trainable position embedding, originally from https://github.com/facebookresearch/detr + """ + + def __init__(self, num_pos_feats: int, device: Device | None = None): + super().__init__() + self.device = device + temperature = 10000 + self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32, device=self.device) + self.dim_t = temperature ** (2 * (self.dim_t // 2) / num_pos_feats) + + def __call__(self, h: int, w: int) -> Tensor: + mask = torch.ones([1, h, w, 1], dtype=torch.bool, device=self.device) + y_embed = mask.cumsum(dim=1, dtype=torch.float32) + x_embed = mask.cumsum(dim=2, dtype=torch.float32) + + eps, scale = 1e-6, 2 * math.pi + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * scale + + pos_x = 
x_embed / self.dim_t + pos_y = y_embed / self.dim_t + + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + return torch.cat((pos_y, pos_x), dim=3).permute(1, 2, 0, 3).flatten(0, 1) + + +class MultiPoolPos(fl.Module): + def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine): + super().__init__() + self.pool_ratios = pool_ratios + self.positional_embedding = positional_embedding + + def forward(self, *args: int) -> Tensor: + h, w = args + return torch.cat([self.positional_embedding(h // ratio, w // ratio) for ratio in self.pool_ratios]) + + +class Repeat(fl.Module): + def __init__(self, dim: int = 0): + self.dim = dim + super().__init__() + + def forward(self, x: Tensor, n: int) -> Tensor: + return torch.repeat_interleave(x, n, dim=self.dim) + + +class _MHA_Arg(fl.Sum): + def __init__(self, offset: int): + self.offset = offset + super().__init__( + fl.GetArg(offset), # value + fl.Chain( + fl.Parallel( + fl.GetArg(self.offset + 1), # position embedding + fl.Lambda(self._batch_size), + ), + Repeat(1), + ), + ) + + def _batch_size(self, *args: Tensor) -> int: + return args[self.offset].size(1) + + +class GlobalAttention(fl.Chain): + # Input must be a 4-tuple: (global, global pos. emb, pools, pools pos. emb.) + def __init__( + self, + emb_dim: int, + num_heads: int = 1, + device: Device | None = None, + ): + super().__init__( + fl.Sum( + fl.GetArg(0), # global + fl.Chain( + fl.Parallel( + _MHA_Arg(0), # Q: global + pos. emb + _MHA_Arg(2), # K: pools + pos. emb + fl.GetArg(2), # V: pools + ), + MultiheadAttention(emb_dim, num_heads, device=device), + ), + ), + ) + + +class MCLM(fl.Chain): + """Multi-View Complementary Localization Module + Inputs: + tensor: (b, 5, e, h, h) + Outputs: + tensor: (b, 5, e, h, h) + """ + + def __init__( + self, + emb_dim: int, + num_heads: int = 1, + pool_ratios: list[int] | None = None, + device: Device | None = None, + ): + if pool_ratios is None: + pool_ratios = [2, 8, 16] + + positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device) + + # LayerNorms in MCLM share their weights. 
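+        # ln1 and ln2 are each applied twice: once directly after the global attention and
+        # once (through the Lambda proxy defined below) after the patchwise cross-attention,
+        # so both passes reuse the same parameters.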
+ + ln1 = fl.LayerNorm(emb_dim, device=device) + ln2 = fl.LayerNorm(emb_dim, device=device) + + def proxy(m: fl.Module) -> fl.Module: + def f(x: Tensor) -> Tensor: + return m(x) + + return fl.Lambda(f) + + super().__init__( + fl.Parallel( + fl.Chain( # global + fl.Slicing(dim=1, start=4), + fl.Squeeze(1), + fl.Parallel( + PerPixel(), # glb + fl.Chain( # g_pos + fl.Lambda(lambda x: x.shape[-2:]), # type: ignore + positional_embedding, + ), + ), + ), + fl.Chain( # local + fl.Slicing(dim=1, end=4), + fl.SetContext("mclm", "local"), + PatchMerge(), + fl.Parallel( + fl.Chain( # pool + MultiPool(pool_ratios), + fl.Squeeze(0), + ), + fl.Chain( # pool_pos + fl.Lambda(lambda x: x.shape[-2:]), # type: ignore + MultiPoolPos(pool_ratios, positional_embedding), + ), + ), + ), + ), + fl.Lambda(lambda t1, t2: (*t1, *t2)), # type: ignore + GlobalAttention(emb_dim, num_heads, device=device), + ln1, + FeedForward(emb_dim, device=device), + ln2, + fl.SetContext("mclm", "global"), + fl.UseContext("mclm", "local"), + fl.Flatten(-2, -1), + fl.Permute(1, 3, 0, 2), + fl.Residual( + fl.Parallel( + fl.Identity(), + fl.Chain( + fl.UseContext("mclm", "global"), + Unflatten(0, (2, 8, 2, 8)), # 2, h/2, 2, h/2 + fl.Permute(0, 2, 1, 3, 4, 5), + fl.Flatten(0, 1), + fl.Flatten(1, 2), + ), + ), + PatchwiseCrossAttention(emb_dim, num_heads, device=device), + ), + proxy(ln1), + FeedForward(emb_dim, device=device), + proxy(ln2), + fl.Concatenate( + fl.Identity(), + fl.Chain( + fl.UseContext("mclm", "global"), + fl.Unsqueeze(0), + ), + ), + Unflatten(1, (16, 16)), # h, h + fl.Permute(3, 0, 4, 1, 2), + ) + + def init_context(self) -> Contexts: + return {"mclm": {"global": None, "local": None}} diff --git a/src/refiners/foundationals/swin/mvanet/mcrm.py b/src/refiners/foundationals/swin/mvanet/mcrm.py new file mode 100644 index 0000000..01311da --- /dev/null +++ b/src/refiners/foundationals/swin/mvanet/mcrm.py @@ -0,0 +1,119 @@ +# Multi-View Complementary Refinement + +import torch +from torch import Tensor, device as Device + +import refiners.fluxion.layers as fl + +from .utils import FeedForward, Interpolate, MultiPool, PatchMerge, PatchSplit, PatchwiseCrossAttention, Unflatten + + +class Multiply(fl.Chain): + def __init__(self, o1: fl.Module, o2: fl.Module) -> None: + super().__init__(o1, o2) + + def forward(self, *args: Tensor) -> Tensor: + return torch.mul(self[0](*args), self[1](*args)) + + +class TiledCrossAttention(fl.Chain): + def __init__( + self, + emb_dim: int, + dim: int, + num_heads: int = 1, + pool_ratios: list[int] | None = None, + device: Device | None = None, + ): + # Input must be a 4-tuple: (local, global) + + if pool_ratios is None: + pool_ratios = [1, 2, 4] + + super().__init__( + fl.Distribute( + fl.Chain( # local + fl.Flatten(-2, -1), + fl.Permute(1, 3, 0, 2), + ), + fl.Chain( # global + PatchSplit(), + fl.Squeeze(0), + MultiPool(pool_ratios), + ), + ), + fl.Sum( + fl.Chain( + fl.GetArg(0), + fl.Permute(2, 1, 0, 3), + ), + fl.Chain( + PatchwiseCrossAttention(emb_dim, num_heads, device=device), + fl.Permute(2, 1, 0, 3), + ), + ), + fl.LayerNorm(emb_dim, device=device), + FeedForward(emb_dim, device=device), + fl.LayerNorm(emb_dim, device=device), + fl.Permute(0, 2, 3, 1), + Unflatten(-1, (dim, dim)), + ) + + +class MCRM(fl.Chain): + """Multi-View Complementary Refinement""" + + def __init__( + self, + emb_dim: int, + size: int, + num_heads: int = 1, + pool_ratios: list[int] | None = None, + device: Device | None = None, + ): + if pool_ratios is None: + pool_ratios = [1, 2, 4] + + super().__init__( + 
fl.Parallel( + fl.Chain( # local + fl.Slicing(dim=1, end=4), + ), + fl.Chain( # global + fl.Slicing(dim=1, start=4), + fl.Squeeze(1), + ), + ), + fl.Parallel( + Multiply( + fl.GetArg(0), + fl.Chain( + fl.GetArg(1), + fl.Conv2d(emb_dim, 1, 1, device=device), + fl.Sigmoid(), + Interpolate((size * 2, size * 2), "nearest"), + PatchSplit(), + ), + ), + fl.GetArg(1), + ), + fl.Parallel( + TiledCrossAttention(emb_dim, size, num_heads, pool_ratios, device=device), + fl.GetArg(1), + ), + fl.Concatenate( + fl.GetArg(0), + fl.Chain( + fl.Sum( + fl.GetArg(1), + fl.Chain( + fl.GetArg(0), + PatchMerge(), + Interpolate((size, size), "nearest"), + ), + ), + fl.Unsqueeze(1), + ), + dim=1, + ), + ) diff --git a/src/refiners/foundationals/swin/mvanet/mvanet.py b/src/refiners/foundationals/swin/mvanet/mvanet.py new file mode 100644 index 0000000..4410eb6 --- /dev/null +++ b/src/refiners/foundationals/swin/mvanet/mvanet.py @@ -0,0 +1,337 @@ +# Multi-View Aggregation Network (arXiv:2404.07445) + +from torch import device as Device + +import refiners.fluxion.layers as fl +from refiners.fluxion.context import Contexts +from refiners.foundationals.swin.swin_transformer import SwinTransformer + +from .mclm import MCLM # Multi-View Complementary Localization +from .mcrm import MCRM # Multi-View Complementary Refinement +from .utils import BatchNorm2d, Interpolate, PatchMerge, PatchSplit, PReLU, Rescale, Unflatten + + +class CBG(fl.Chain): + """(C)onvolution + (B)atchNorm + (G)eLU""" + + def __init__( + self, + in_dim: int, + out_dim: int | None = None, + device: Device | None = None, + ): + out_dim = out_dim or in_dim + super().__init__( + fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device), + BatchNorm2d(out_dim, device=device), + fl.GeLU(), + ) + + +class CBR(fl.Chain): + """(C)onvolution + (B)atchNorm + Parametric (R)eLU""" + + def __init__( + self, + in_dim: int, + out_dim: int | None = None, + device: Device | None = None, + ): + out_dim = out_dim or in_dim + super().__init__( + fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device), + BatchNorm2d(out_dim, device=device), + PReLU(device=device), + ) + + +class SplitMultiView(fl.Chain): + """ + Split a hd tensor into 5 ld views, (5 = 1 global + 4 tiles) + See also the reverse Module [`RearrangeMultiView`][refiners.foundationals.swin.mvanet.RearrangeMultiView] + + Inputs: + single_view (b, c, H, W) + + Outputs: + multi_view (b, 5, c, H/2, W/2) + """ + + def __init__(self): + super().__init__( + fl.Concatenate( + PatchSplit(), # global features + fl.Chain( # local features + Rescale(scale_factor=0.5, mode="bilinear"), + fl.Unsqueeze(1), + ), + dim=1, + ) + ) + + +class ShallowUpscaler(fl.Chain): + """4x Upscaler reusing the image as input to upscale the feature + See [[arXiv:2108.10257] SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/abs/2108.10257) + + Args: + embedding_dim (int): the embedding dimension + + Inputs: + feature (b, E, image_size/4, image_size/4) + + Output: + upscaled tensor (b, E, image_size, image_size) + """ + + def __init__( + self, + embedding_dim: int = 128, + device: Device | None = None, + ): + super().__init__( + fl.Sum( + fl.Identity(), + fl.Chain( + fl.UseContext("mvanet", "shallow"), + Interpolate((256, 256)), + ), + ), + fl.Sum( + fl.Chain( + Rescale(2), + CBG(embedding_dim, device=device), + ), + fl.Chain( + fl.UseContext("mvanet", "shallow"), + Interpolate((512, 512)), + ), + ), + Rescale(2), + CBG(embedding_dim, device=device), + ) + + +class PyramidL5(fl.Chain): + def __init__( + 
self, + embedding_dim: int = 128, + device: Device | None = None, + ): + super().__init__( + fl.GetArg(0), # output5 + fl.Flatten(0, 1), + CBR(1024, embedding_dim, device=device), + Unflatten(0, (-1, 5)), + MCLM(embedding_dim, device=device), + fl.Flatten(0, 1), + Interpolate((32, 32)), + ) + + +class PyramidL4(fl.Chain): + def __init__( + self, + embedding_dim: int = 128, + device: Device | None = None, + ): + super().__init__( + fl.Sum( + PyramidL5(embedding_dim=embedding_dim, device=device), + fl.Chain( + fl.GetArg(1), + fl.Flatten(0, 1), + CBR(512, embedding_dim, device=device), # output4 + Unflatten(0, (-1, 5)), + ), + ), + MCRM(embedding_dim, 32, device=device), # dec_blk4 + fl.Flatten(0, 1), + CBR(embedding_dim, device=device), # conv4 + Interpolate((64, 64)), + ) + + +class PyramidL3(fl.Chain): + def __init__( + self, + embedding_dim: int = 128, + device: Device | None = None, + ): + super().__init__( + fl.Sum( + PyramidL4(embedding_dim=embedding_dim, device=device), + fl.Chain( + fl.GetArg(2), + fl.Flatten(0, 1), + CBR(256, embedding_dim, device=device), # output3 + Unflatten(0, (-1, 5)), + ), + ), + MCRM(embedding_dim, 64, device=device), # dec_blk3 + fl.Flatten(0, 1), + CBR(embedding_dim, device=device), # conv3 + Interpolate((128, 128)), + ) + + +class PyramidL2(fl.Chain): + def __init__( + self, + embedding_dim: int = 128, + device: Device | None = None, + ): + embedding_dim = 128 + super().__init__( + fl.Sum( + PyramidL3(embedding_dim=embedding_dim, device=device), + fl.Chain( + fl.GetArg(3), + fl.Flatten(0, 1), + CBR(128, embedding_dim, device=device), # output2 + Unflatten(0, (-1, 5)), + ), + ), + MCRM(embedding_dim, 128, device=device), # dec_blk2 + fl.Flatten(0, 1), + CBR(embedding_dim, device=device), # conv2 + Interpolate((128, 128)), + ) + + +class Pyramid(fl.Chain): + """ + Recursive Pyramidal Network calling MCLM and MCRM blocks + + It acts as a FPN (Feature Pyramid Network) Neck for MVANet + see [[arXiv:1612.03144] Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144) + + Inputs: + features: a pyramid of N = 5 tensors + shapes are (b, 5, E_{0}, S_{0}, S_{0}), ..., (b, 5, E_{1}, S_{i}, S_{i}), ..., (b, 5, E_{N-1}, S_{N-1}, S_{N-1}) + with S_{i} = S_{i-1} or S_{i} = 2*S_{i-1} for 0 < i < N + + Outputs: + output (b, 5, E, S_{N-1}, S_{N-1}) + """ + + def __init__( + self, + embedding_dim: int = 128, + device: Device | None = None, + ): + super().__init__( + fl.Sum( + PyramidL2(embedding_dim=embedding_dim, device=device), + fl.Chain( + fl.GetArg(4), + fl.Flatten(0, 1), + CBR(128, embedding_dim, device=device), # output1 + Unflatten(0, (-1, 5)), + ), + ), + MCRM(embedding_dim, 128, device=device), # dec_blk1 + fl.Flatten(0, 1), + CBR(embedding_dim, device=device), # conv1 + Unflatten(0, (-1, 5)), + ) + + +class RearrangeMultiView(fl.Chain): + """ + Inputs: + multi_view (b, 5, E, H, W) + + Outputs: + single_view (b, E, H*2, W*2) + + Fusion a multi view tensor into a single view tensor, using convolutions + See also the reverse Module [`SplitMultiView`][refiners.foundationals.swin.mvanet.SplitMultiView] + """ + + def __init__( + self, + embedding_dim: int = 128, + device: Device | None = None, + ): + super().__init__( + fl.Sum( + fl.Chain( # local features + fl.Slicing(dim=1, end=4), + PatchMerge(), + ), + fl.Chain( # global feature + fl.Slicing(dim=1, start=4), + fl.Squeeze(1), + Interpolate((256, 256)), + ), + ), + fl.Chain( # conv head + CBR(embedding_dim, 384, device=device), + CBR(384, device=device), + fl.Conv2d(384, embedding_dim, 
kernel_size=3, padding=1, device=device), + ), + ) + + +class ComputeShallow(fl.Passthrough): + def __init__( + self, + embedding_dim: int = 128, + device: Device | None = None, + ): + super().__init__( + fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device), + fl.SetContext("mvanet", "shallow"), + ) + + +class MVANet(fl.Chain): + """Multi-view Aggregation Network for Dichotomous Image Segmentation + + See [[arXiv:2404.07445] Multi-view Aggregation Network for Dichotomous Image Segmentation](https://arxiv.org/abs/2404.07445) for more details. + + Args: + embedding_dim (int): embedding dimension + n_logits (int): the number of output logits (default to 1) + 1 logit is used for alpha matting/foreground-background segmentation/sod segmentation + depths (list[int]): see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer] + num_heads (list[int]): see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer] + window_size (int): default to 12, see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer] + device (Device | None): the device to use + """ + + def __init__( + self, + embedding_dim: int = 128, + n_logits: int = 1, + depths: list[int] | None = None, + num_heads: list[int] | None = None, + window_size: int = 12, + device: Device | None = None, + ): + if depths is None: + depths = [2, 2, 18, 2] + if num_heads is None: + num_heads = [4, 8, 16, 32] + + super().__init__( + ComputeShallow(embedding_dim=embedding_dim, device=device), + SplitMultiView(), + fl.Flatten(0, 1), + SwinTransformer( + embedding_dim=embedding_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + device=device, + ), + fl.Distribute(*(Unflatten(0, (-1, 5)) for _ in range(5))), + Pyramid(embedding_dim=embedding_dim, device=device), + RearrangeMultiView(embedding_dim=embedding_dim, device=device), + ShallowUpscaler(embedding_dim, device=device), + fl.Conv2d(embedding_dim, n_logits, kernel_size=3, padding=1, device=device), + ) + + def init_context(self) -> Contexts: + return {"mvanet": {"shallow": None}} diff --git a/src/refiners/foundationals/swin/mvanet/utils.py b/src/refiners/foundationals/swin/mvanet/utils.py new file mode 100644 index 0000000..363b55a --- /dev/null +++ b/src/refiners/foundationals/swin/mvanet/utils.py @@ -0,0 +1,173 @@ +import torch +from torch import Size, Tensor +from torch.nn.functional import ( + adaptive_avg_pool2d, + interpolate, # type: ignore +) + +import refiners.fluxion.layers as fl + + +class Unflatten(fl.Module): + def __init__(self, dim: int, sizes: tuple[int, ...]) -> None: + super().__init__() + self.dim = dim + self.sizes = Size(sizes) + + def forward(self, x: Tensor) -> Tensor: + return torch.unflatten(input=x, dim=self.dim, sizes=self.sizes) + + +class Interpolate(fl.Module): + def __init__(self, size: tuple[int, ...], mode: str = "bilinear"): + super().__init__() + self.size = Size(size) + self.mode = mode + + def forward(self, x: Tensor) -> Tensor: + return interpolate(x, size=self.size, mode=self.mode) # type: ignore + + +class Rescale(fl.Module): + def __init__(self, scale_factor: float, mode: str = "nearest"): + super().__init__() + self.scale_factor = scale_factor + self.mode = mode + + def forward(self, x: Tensor) -> Tensor: + return interpolate(x, scale_factor=self.scale_factor, mode=self.mode) # type: ignore + + +class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule): + def __init__(self, num_features: int, device: torch.device | None = None): + 
super().__init__(num_features=num_features, device=device) # type: ignore + + +class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation): + def __init__(self, device: torch.device | None = None): + super().__init__(device=device) # type: ignore + + +class PatchSplit(fl.Chain): + """(B, N, H, W) -> B, 4, N, H/2, W/2""" + + def __init__(self): + super().__init__( + Unflatten(-2, (2, -1)), + Unflatten(-1, (2, -1)), + fl.Permute(0, 2, 4, 1, 3, 5), + fl.Flatten(1, 2), + ) + + +class PatchMerge(fl.Chain): + """B, 4, N, H, W -> (B, N, 2*H, 2*W)""" + + def __init__(self): + super().__init__( + Unflatten(1, (2, 2)), + fl.Permute(0, 3, 1, 4, 2, 5), + fl.Flatten(-2, -1), + fl.Flatten(-3, -2), + ) + + +class FeedForward(fl.Residual): + def __init__(self, emb_dim: int, device: torch.device | None = None) -> None: + super().__init__( + fl.Linear(in_features=emb_dim, out_features=2 * emb_dim, device=device), + fl.ReLU(), + fl.Linear(in_features=2 * emb_dim, out_features=emb_dim, device=device), + ) + + +class _GetArgs(fl.Parallel): + def __init__(self, n: int): + super().__init__( + fl.Chain( + fl.GetArg(0), + fl.Slicing(dim=0, start=n, end=n + 1), + fl.Squeeze(0), + ), + fl.Chain( + fl.GetArg(1), + fl.Slicing(dim=0, start=n, end=n + 1), + fl.Squeeze(0), + ), + fl.Chain( + fl.GetArg(1), + fl.Slicing(dim=0, start=n, end=n + 1), + fl.Squeeze(0), + ), + ) + + +class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule): + def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None): + super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore + + @property + def weight(self) -> Tensor: # type: ignore + return self.in_proj_weight + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # type: ignore + return super().forward(q, k, v)[0] + + +class PatchwiseCrossAttention(fl.Chain): + # Input is 2 tensors of sizes (4, HW, B, C) and (4, HW', B, C), + # output is size (4, HW, B, C). + def __init__( + self, + d_model: int, + num_heads: int, + device: torch.device | None = None, + ): + super().__init__( + fl.Concatenate( + fl.Chain( + _GetArgs(0), + MultiheadAttention(d_model, num_heads, device=device), + ), + fl.Chain( + _GetArgs(1), + MultiheadAttention(d_model, num_heads, device=device), + ), + fl.Chain( + _GetArgs(2), + MultiheadAttention(d_model, num_heads, device=device), + ), + fl.Chain( + _GetArgs(3), + MultiheadAttention(d_model, num_heads, device=device), + ), + ), + Unflatten(0, (4, -1)), + ) + + +class Pool(fl.Module): + def __init__(self, ratio: int) -> None: + super().__init__() + self.ratio = ratio + + def forward(self, x: Tensor) -> Tensor: + b, _, h, w = x.shape + assert h % self.ratio == 0 and w % self.ratio == 0 + r = adaptive_avg_pool2d(x, (h // self.ratio, w // self.ratio)) + return torch.unflatten(r, 0, (b, -1)) + + +class MultiPool(fl.Concatenate): + def __init__(self, pool_ratios: list[int]) -> None: + super().__init__( + *( + fl.Chain( + Pool(pool_ratio), + fl.Flatten(-2, -1), + fl.Permute(0, 3, 1, 2), + ) + for pool_ratio in pool_ratios + ), + dim=1, + ) diff --git a/src/refiners/foundationals/swin/swin_transformer.py b/src/refiners/foundationals/swin/swin_transformer.py new file mode 100644 index 0000000..f1291fe --- /dev/null +++ b/src/refiners/foundationals/swin/swin_transformer.py @@ -0,0 +1,391 @@ +# Swin Transformer (arXiv:2103.14030) +# +# Specific to MVANet, only supports square inputs. 
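+# Unlike the reference classification model, no classification head is included: the network
+# returns the per-stage feature maps instead (see SwinTransformer below).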
+# Originally adapted from the version in MVANet and InSPyReNet (https://github.com/plemeri/InSPyReNet) +# Original implementation by Microsoft at https://github.com/microsoft/Swin-Transformer + +import functools +from math import isqrt + +import torch +from torch import Tensor, device as Device + +import refiners.fluxion.layers as fl +from refiners.fluxion.context import Contexts + + +def to_windows(x: Tensor, window_size: int) -> Tensor: + B, H, W, C = x.shape + assert W == H and H % window_size == 0 + x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C) + return x.permute(0, 1, 3, 2, 4, 5).reshape(B, -1, window_size * window_size, C) + + +class ToWindows(fl.Module): + def __init__(self, window_size: int): + super().__init__() + self.window_size = window_size + + def forward(self, x: Tensor) -> Tensor: + return to_windows(x, self.window_size) + + +class FromWindows(fl.Module): + def forward(self, x: Tensor) -> Tensor: + B, num_windows, window_size_2, C = x.shape + window_size = isqrt(window_size_2) + H = isqrt(num_windows * window_size_2) + x = x.reshape(B, H // window_size, H // window_size, window_size, window_size, C) + return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, H, C) + + +@functools.cache +def get_attn_mask(H: int, window_size: int, device: Device | None = None) -> Tensor: + assert H % window_size == 0 + shift_size = window_size // 2 + img_mask = torch.zeros((1, H, H, 1), device=device) + h_slices = ( + slice(0, -window_size), + slice(-window_size, -shift_size), + slice(-shift_size, None), + ) + w_slices = ( + slice(0, -window_size), + slice(-window_size, -shift_size), + slice(-shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = to_windows(img_mask, window_size).squeeze() # B, nW, window_size * window_size, [1] + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask.masked_fill_(attn_mask != 0, -100.0).masked_fill_(attn_mask == 0, 0.0) + return attn_mask + + +class Pad(fl.Module): + def __init__(self, step: int): + super().__init__() + self.step = step + + def forward(self, x: Tensor) -> Tensor: + B, H, W, C = x.shape + assert W == H + if H % self.step == 0: + return x + p = self.step * ((H + self.step - 1) // self.step) + padded = torch.zeros(B, p, p, C, device=x.device, dtype=x.dtype) + padded[:, :H, :H, :] = x + return padded + + +class StatefulPad(fl.Chain): + def __init__(self, context: str, key: str, step: int) -> None: + super().__init__( + fl.SetContext(context=context, key=key, callback=self._push), + Pad(step=step), + ) + + def _push(self, sizes: list[int], x: Tensor) -> None: + sizes.append(x.size(1)) + + +class StatefulUnpad(fl.Chain): + def __init__(self, context: str, key: str) -> None: + super().__init__( + fl.Parallel( + fl.Identity(), + fl.UseContext(context=context, key=key).compose(lambda x: x.pop()), + ), + fl.Lambda(self._unpad), + ) + + @staticmethod + def _unpad(x: Tensor, size: int) -> Tensor: + return x[:, :size, :size, :] + + +class SquareUnflatten(fl.Module): + # ..., L^2, ... -> ..., L, L, ... + + def __init__(self, dim: int = 0) -> None: + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + d = isqrt(x.shape[self.dim]) + return torch.unflatten(x, self.dim, (d, d)) + + +class WindowUnflatten(fl.Module): + # ..., H, ... -> ..., H // ws, ws, ... 
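+    # i.e. splits dimension `dim` (of size H, which must be a multiple of window_size)
+    # into two dimensions (H // window_size, window_size).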
+ + def __init__(self, window_size: int, dim: int = 0) -> None: + super().__init__() + self.window_size = window_size + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.dim] % self.window_size == 0 + H = x.shape[self.dim] + return torch.unflatten(x, self.dim, (H // self.window_size, self.window_size)) + + +class Roll(fl.Module): + def __init__(self, *shifts: tuple[int, int]): + super().__init__() + self.shifts = shifts + self._dims = tuple(s[0] for s in shifts) + self._shifts = tuple(s[1] for s in shifts) + + def forward(self, x: Tensor) -> Tensor: + return torch.roll(x, self._shifts, self._dims) + + +class RelativePositionBias(fl.Module): + relative_position_index: Tensor + + def __init__(self, window_size: int, num_heads: int, device: Device | None = None): + super().__init__() + self.relative_position_bias_table = torch.nn.Parameter( + torch.empty( + (2 * window_size - 1) * (2 * window_size - 1), + num_heads, + device=device, + ) + ) + relative_position_index = torch.empty( + window_size**2, + window_size**2, + device=device, + dtype=torch.int64, + ) + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self) -> Tensor: + # Yes, this is a (trainable) constant. + return self.relative_position_bias_table[self.relative_position_index].permute(2, 0, 1).unsqueeze(0) + + +class WindowSDPA(fl.Module): + def __init__( + self, + dim: int, + window_size: int, + num_heads: int, + shift: bool = False, + device: Device | None = None, + ): + super().__init__() + self.window_size = window_size + self.num_heads = num_heads + self.shift = shift + self.rpb = RelativePositionBias(window_size, num_heads, device=device) + + def forward(self, x: Tensor): + B, num_windows, N, _C = x.shape + assert _C % (3 * self.num_heads) == 0 + C = _C // 3 + x = torch.reshape(x, (B * num_windows, N, 3, self.num_heads, C // self.num_heads)) + q, k, v = x.permute(2, 0, 3, 1, 4) + + attn_mask = self.rpb() + if self.shift: + mask = get_attn_mask(isqrt(num_windows * (self.window_size**2)), self.window_size, x.device) + mask = mask.reshape(1, num_windows, 1, N, N) + mask = mask.expand(B, -1, self.num_heads, -1, -1) + attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask) + x = x.transpose(1, 2).reshape(B, num_windows, N, C) + return x + + +class WindowAttention(fl.Chain): + """ + Window-based Multi-head Self-Attenion (W-MSA), optionally shifted (SW-MSA). + + It has a trainable relative position bias (RelativePositionBias). + + The input projection is stored as a single Linear for q, k and v. 
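+    The official checkpoints fuse q, k and v the same way, so the MVANet weight converter
+    maps `attn.qkv` directly onto this layer (`Linear_1`).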
+ """ + + def __init__( + self, + dim: int, + window_size: int, + num_heads: int, + shift: bool = False, + device: Device | None = None, + ): + super().__init__( + fl.Linear(dim, dim * 3, bias=True, device=device), + WindowSDPA(dim, window_size, num_heads, shift, device=device), + fl.Linear(dim, dim, device=device), + ) + + +class SwinTransformerBlock(fl.Chain): + def __init__( + self, + dim: int, + num_heads: int, + window_size: int = 7, + shift_size: int = 0, + mlp_ratio: float = 4.0, + device: Device | None = None, + ): + assert 0 <= shift_size < window_size, "shift_size must in [0, window_size[" + + super().__init__( + fl.Residual( + fl.LayerNorm(dim, device=device), + SquareUnflatten(1), + StatefulPad(context="padding", key="sizes", step=window_size), + Roll((1, -shift_size), (2, -shift_size)), + ToWindows(window_size), + WindowAttention( + dim, + window_size=window_size, + num_heads=num_heads, + shift=shift_size > 0, + device=device, + ), + FromWindows(), + Roll((1, shift_size), (2, shift_size)), + StatefulUnpad(context="padding", key="sizes"), + fl.Flatten(1, 2), + ), + fl.Residual( + fl.LayerNorm(dim, device=device), + fl.Linear(dim, int(dim * mlp_ratio), device=device), + fl.GeLU(), + fl.Linear(int(dim * mlp_ratio), dim, device=device), + ), + ) + + def init_context(self) -> Contexts: + return {"padding": {"sizes": []}} + + +class PatchMerging(fl.Chain): + def __init__(self, dim: int, device: Device | None = None): + super().__init__( + SquareUnflatten(1), + Pad(2), + WindowUnflatten(2, 2), + WindowUnflatten(2, 1), + fl.Permute(0, 1, 3, 4, 2, 5), + fl.Flatten(3), + fl.Flatten(1, 2), + fl.LayerNorm(4 * dim, device=device), + fl.Linear(4 * dim, 2 * dim, bias=False, device=device), + ) + + +class BasicLayer(fl.Chain): + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + window_size: int = 7, + mlp_ratio: float = 4.0, + device: Device | None = None, + ): + super().__init__( + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + device=device, + ) + for i in range(depth) + ) + + +class PatchEmbedding(fl.Chain): + def __init__( + self, + patch_size: tuple[int, int] = (4, 4), + in_chans: int = 3, + embedding_dim: int = 96, + device: Device | None = None, + ): + super().__init__( + fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device), + fl.Flatten(2), + fl.Transpose(1, 2), + fl.LayerNorm(embedding_dim, device=device), + ) + + +class SwinTransformer(fl.Chain): + """Swin Transformer (arXiv:2103.14030) + + Currently specific to MVANet, only supports square inputs. 
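+
+    The forward pass returns a tuple of len(depths) + 1 feature maps, ordered from the
+    deepest stage down to the patch embedding output.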
+ """ + + def __init__( + self, + patch_size: tuple[int, int] = (4, 4), + in_chans: int = 3, + embedding_dim: int = 96, + depths: list[int] | None = None, + num_heads: list[int] | None = None, + window_size: int = 7, # image size is 32 * this + mlp_ratio: float = 4.0, + device: Device | None = None, + ): + if depths is None: + depths = [2, 2, 6, 2] + + if num_heads is None: + num_heads = [3, 6, 12, 24] + + self.num_layers = len(depths) + assert len(num_heads) == self.num_layers + + super().__init__( + PatchEmbedding( + patch_size=patch_size, + in_chans=in_chans, + embedding_dim=embedding_dim, + device=device, + ), + fl.Passthrough( + fl.Transpose(1, 2), + SquareUnflatten(2), + fl.SetContext("swin", "outputs", callback=lambda t, x: t.append(x)), + ), + *( + fl.Chain( + BasicLayer( + dim=int(embedding_dim * 2**i), + depth=depths[i], + num_heads=num_heads[i], + window_size=window_size, + mlp_ratio=mlp_ratio, + device=device, + ), + fl.Passthrough( + fl.LayerNorm(int(embedding_dim * 2**i), device=device), + fl.Transpose(1, 2), + SquareUnflatten(2), + fl.SetContext("swin", "outputs", callback=lambda t, x: t.insert(0, x)), + ), + PatchMerging(dim=int(embedding_dim * 2**i), device=device) + if i < self.num_layers - 1 + else fl.UseContext("swin", "outputs").compose(lambda t: tuple(t)), + ) + for i in range(self.num_layers) + ), + ) + + def init_context(self) -> Contexts: + return {"swin": {"outputs": []}} diff --git a/tests/e2e/test_mvanet.py b/tests/e2e/test_mvanet.py new file mode 100644 index 0000000..08a47ad --- /dev/null +++ b/tests/e2e/test_mvanet.py @@ -0,0 +1,59 @@ +from pathlib import Path +from warnings import warn + +import pytest +import torch +from PIL import Image +from tests.utils import ensure_similar_images + +from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image +from refiners.foundationals.swin.mvanet import MVANet + + +def _img_open(path: Path) -> Image.Image: + return Image.open(path) # type: ignore + + +@pytest.fixture(scope="module") +def ref_path(test_e2e_path: Path) -> Path: + return test_e2e_path / "test_mvanet_ref" + + +@pytest.fixture(scope="module") +def ref_cactus(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "cactus.png").convert("RGB") + + +@pytest.fixture +def expected_cactus_mask(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "expected_cactus_mask.png") + + +@pytest.fixture(scope="module") +def mvanet_weights(test_weights_path: Path) -> Path: + weights = test_weights_path / "mvanet" / "mvanet.safetensors" + if not weights.is_file(): + warn(f"could not find weights at {test_weights_path}, skipping") + pytest.skip(allow_module_level=True) + return weights + + +@pytest.fixture +def mvanet_model(mvanet_weights: Path, test_device: torch.device) -> MVANet: + model = MVANet(device=test_device).eval() # .eval() is important! 
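+    # eval() makes the BatchNorm2d layers use their converted running statistics instead of
+    # per-batch statistics, which is required to reproduce the reference mask.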
+ model.load_from_safetensors(mvanet_weights) + return model + + +@no_grad() +def test_mvanet( + mvanet_model: MVANet, + ref_cactus: Image.Image, + expected_cactus_mask: Image.Image, + test_device: torch.device, +): + in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze() + in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0) + prediction: torch.Tensor = mvanet_model(in_t.to(test_device)).sigmoid() + cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR) + ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB")) diff --git a/tests/e2e/test_mvanet_ref/README.md b/tests/e2e/test_mvanet_ref/README.md new file mode 100644 index 0000000..687f9f8 --- /dev/null +++ b/tests/e2e/test_mvanet_ref/README.md @@ -0,0 +1,3 @@ +`cactus.png` is cropped from this image: https://www.freepik.com/free-photo/laptop-notebook-pen-coffee-cup-plants-wooden-desk_269339828.htm + +`expected_cactus_mask.png` has been generated using the [official MVANet codebase](https://github.com/qianyu-dlut/MVANet) and weights. diff --git a/tests/e2e/test_mvanet_ref/cactus.png b/tests/e2e/test_mvanet_ref/cactus.png new file mode 100644 index 0000000..b917acb Binary files /dev/null and b/tests/e2e/test_mvanet_ref/cactus.png differ diff --git a/tests/e2e/test_mvanet_ref/expected_cactus_mask.png b/tests/e2e/test_mvanet_ref/expected_cactus_mask.png new file mode 100644 index 0000000..ac7fd77 Binary files /dev/null and b/tests/e2e/test_mvanet_ref/expected_cactus_mask.png differ diff --git a/typings/gdown/__init__.pyi b/typings/gdown/__init__.pyi new file mode 100644 index 0000000..01d0b52 --- /dev/null +++ b/typings/gdown/__init__.pyi @@ -0,0 +1 @@ +def download(id: str, output: str, quiet: bool = False) -> str: ...