Add Multi-View Aggregation Network (MVANet)

Co-authored-by: Pierre Colle <piercus@gmail.com>
Pierre Chapuis 2024-08-10 15:17:36 +02:00
parent 58c1cc7cd4
commit 10dfa73a09
19 changed files with 1525 additions and 1 deletion


@ -9,4 +9,4 @@
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> DINOv2](foundationals/dinov2.md)
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Latent Diffusion](foundationals/latent_diffusion.md)
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Segment Anything](foundationals/segment_anything.md)
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Swin Transformers](foundationals/swin.md)


@ -0,0 +1,2 @@
::: refiners.foundationals.swin.swin_transformer
::: refiners.foundationals.swin.mvanet


@ -58,6 +58,7 @@ conversion = [
"segment-anything-py>=1.0", "segment-anything-py>=1.0",
"requests>=2.26.0", "requests>=2.26.0",
"tqdm>=4.62.3", "tqdm>=4.62.3",
"gdown>=5.2.0",
] ]
doc = [ doc = [
# required by mkdocs to format the signatures # required by mkdocs to format the signatures


@ -31,6 +31,8 @@ babel==2.15.0
# via mkdocs-material
backports-strenum==1.3.1
# via griffe
beautifulsoup4==4.12.3
# via gdown
bitsandbytes==0.43.3
# via refiners
black==24.4.2
@ -70,6 +72,7 @@ docker-pycreds==0.4.0
filelock==3.15.4
# via datasets
# via diffusers
# via gdown
# via huggingface-hub
# via torch
# via transformers
@ -85,6 +88,8 @@ fsspec==2024.5.0
# via torch
future==1.0.0
# via neptune
gdown==5.2.0
# via refiners
ghp-import==2.1.0
# via mkdocs
gitdb==4.0.11
@ -274,6 +279,8 @@ pyjwt==2.9.0
pymdown-extensions==10.9
# via mkdocs-material
# via mkdocstrings
pysocks==1.7.1
# via requests
python-dateutil==2.9.0.post0
# via arrow
# via botocore
@ -311,6 +318,7 @@ requests==2.32.3
# via bravado-core
# via datasets
# via diffusers
# via gdown
# via huggingface-hub
# via mkdocs-material
# via neptune
@ -356,6 +364,8 @@ six==1.16.0
# via rfc3339-validator
smmap==5.0.1
# via gitdb
soupsieve==2.6
# via beautifulsoup4
swagger-spec-validator==3.0.4
# via bravado-core
# via neptune
@ -383,6 +393,7 @@ torchvision==0.19.0
# via timm
tqdm==4.66.4
# via datasets
# via gdown
# via huggingface-hub
# via refiners
# via transformers


@ -0,0 +1,40 @@
import argparse
from pathlib import Path
from refiners.fluxion.utils import load_tensors, save_to_safetensors
from refiners.foundationals.swin.mvanet.converter import convert_weights
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--from",
type=str,
required=True,
dest="source_path",
help="A MVANet checkpoint. One can be found at https://github.com/qianyu-dlut/MVANet",
)
parser.add_argument(
"--to",
type=str,
dest="output_path",
default=None,
help=(
"Path to save the converted model. If not specified, the output path will be the source path with the"
" extension changed to .safetensors."
),
)
parser.add_argument("--half", action="store_true", dest="half")
args = parser.parse_args()
src_weights = load_tensors(args.source_path)
weights = convert_weights(src_weights)
if args.half:
weights = {key: value.half() for key, value in weights.items()}
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"
save_to_safetensors(path=args.output_path, tensors=weights)
if __name__ == "__main__":
main()
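
Usage sketch (not part of the diff): the weights-preparation script further down invokes this as `python convert_mvanet.py --from tests/weights/mvanet/Model_80.pth --to tests/weights/mvanet/mvanet.safetensors --half`. The same conversion can be done programmatically, assuming a local checkpoint path:

from refiners.fluxion.utils import load_tensors, save_to_safetensors
from refiners.foundationals.swin.mvanet.converter import convert_weights

src_weights = load_tensors("Model_80.pth")  # assumed path to the official checkpoint
weights = convert_weights(src_weights)
weights = {key: value.half() for key, value in weights.items()}  # optional, same as --half
save_to_safetensors(path="mvanet.safetensors", tensors=weights)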


@ -11,6 +11,7 @@ import subprocess
import sys
from urllib.parse import urlparse
import gdown
import requests
from tqdm import tqdm
@ -446,6 +447,25 @@ def download_ic_light():
)
def download_mvanet():
fn = "Model_80.pth"
dest_folder = os.path.join(test_weights_dir, "mvanet")
dest_filename = os.path.join(dest_folder, fn)
if os.environ.get("DRY_RUN") == "1":
return
if os.path.exists(dest_filename):
print(f"✖️ Skipping previously downloaded mvanet/{fn}")
else:
os.makedirs(dest_folder, exist_ok=True)
print(f"🔽 Downloading mvanet/{fn} => '{rel(dest_filename)}'", end="\n")
gdown.download(id="1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv", output=dest_filename, quiet=True)
print(f"{previous_line}✅ Downloaded mvanet/{fn} => '{rel(dest_filename)}' ")
check_hash(dest_filename, "b915d492")
def printg(msg: str):
"""print in green color"""
print("\033[92m" + msg + "\033[0m")
@ -808,6 +828,16 @@ def convert_ic_light():
)
def convert_mvanet():
run_conversion_script(
"convert_mvanet.py",
"tests/weights/mvanet/Model_80.pth",
"tests/weights/mvanet/mvanet.safetensors",
half=True,
expected_hash="bf9ae4cb",
)
def download_all():
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
download_sd15("runwayml/stable-diffusion-v1-5")
@ -830,6 +860,7 @@ def download_all():
download_sdxl_lightning_base()
download_sdxl_lightning_lora()
download_ic_light()
download_mvanet()
def convert_all():
@ -850,6 +881,7 @@ def convert_all():
convert_lcm_base()
convert_sdxl_lightning_base()
convert_ic_light()
convert_mvanet()
def main():


@ -0,0 +1,3 @@
from .swin_transformer import SwinTransformer
__all__ = ["SwinTransformer"]


@ -0,0 +1,3 @@
from .mvanet import MVANet
__all__ = ["MVANet"]


@ -0,0 +1,138 @@
import re
from torch import Tensor
def convert_weights(official_state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
rm_list = [
# The official weights contain unused keys
# See https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425
r"multifieldcrossatt.linear[56]",
r"multifieldcrossatt.attention.5",
r"dec_blk\d+\.linear[12]",
r"dec_blk[1234]\.attention\.[4567]",
# We don't need the sideout weights
r"sideout\d+",
]
state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)}
keys_map: dict[str, str] = {}
for k in state_dict.keys():
v: str = k
def rpfx(s: str, src: str, dst: str) -> str:
if not s.startswith(src):
return s
return s.replace(src, dst, 1)
# Swin Transformer backbone
v = rpfx(v, "backbone.patch_embed.proj.", "SwinTransformer.PatchEmbedding.Conv2d.")
v = rpfx(v, "backbone.patch_embed.norm.", "SwinTransformer.PatchEmbedding.LayerNorm.")
if m := re.match(r"backbone\.layers\.(\d+)\.downsample\.(.*)", v):
s = m.group(2).replace("reduction.", "Linear.").replace("norm.", "LayerNorm.")
v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.PatchMerging.{s}"
if m := re.match(r"backbone\.layers\.(\d+)\.blocks\.(\d+)\.(.*)", v):
s = m.group(3)
s = s.replace("norm1.", "Residual_1.LayerNorm.")
s = s.replace("norm2.", "Residual_2.LayerNorm.")
s = s.replace("attn.qkv.", "Residual_1.WindowAttention.Linear_1.")
s = s.replace("attn.proj.", "Residual_1.WindowAttention.Linear_2.")
s = s.replace("attn.relative_position", "Residual_1.WindowAttention.WindowSDPA.rpb.relative_position")
s = s.replace("mlp.fc", "Residual_2.Linear_")
v = ".".join(
[
f"SwinTransformer.Chain_{int(m.group(1)) + 1}",
f"BasicLayer.SwinTransformerBlock_{int(m.group(2)) + 1}",
s,
]
)
if m := re.match(r"backbone\.norm(\d+)\.(.*)", v):
v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.Passthrough.LayerNorm.{m.group(2)}"
# MVANet
def mclm(s: str, pfx_src: str, pfx_dst: str) -> str:
pca = f"{pfx_dst}Residual.PatchwiseCrossAttention"
s = rpfx(s, f"{pfx_src}linear1.", f"{pfx_dst}FeedForward_1.Linear_1.")
s = rpfx(s, f"{pfx_src}linear2.", f"{pfx_dst}FeedForward_1.Linear_2.")
s = rpfx(s, f"{pfx_src}linear3.", f"{pfx_dst}FeedForward_2.Linear_1.")
s = rpfx(s, f"{pfx_src}linear4.", f"{pfx_dst}FeedForward_2.Linear_2.")
s = rpfx(s, f"{pfx_src}norm1.", f"{pfx_dst}LayerNorm_1.")
s = rpfx(s, f"{pfx_src}norm2.", f"{pfx_dst}LayerNorm_2.")
s = rpfx(s, f"{pfx_src}attention.0.", f"{pfx_dst}GlobalAttention.Sum.Chain.MultiheadAttention.")
s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.")
s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.")
s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.")
s = rpfx(s, f"{pfx_src}attention.4.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.")
return s
def mcrm(s: str, pfx_src: str, pfx_dst: str) -> str:
# Note: there are no linear{1,2}, see https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425
tca = f"{pfx_dst}Parallel_3.TiledCrossAttention"
pca = f"{tca}.Sum.Chain_2.PatchwiseCrossAttention"
s = rpfx(s, f"{pfx_src}linear3.", f"{tca}.FeedForward.Linear_1.")
s = rpfx(s, f"{pfx_src}linear4.", f"{tca}.FeedForward.Linear_2.")
s = rpfx(s, f"{pfx_src}norm1.", f"{tca}.LayerNorm_1.")
s = rpfx(s, f"{pfx_src}norm2.", f"{tca}.LayerNorm_2.")
s = rpfx(s, f"{pfx_src}attention.0.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.")
s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.")
s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.")
s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.")
s = rpfx(s, f"{pfx_src}sal_conv.", f"{pfx_dst}Parallel_2.Multiply.Chain.Conv2d.")
return s
def cbr(s: str, pfx_src: str, pfx_dst: str, shift: int = 0) -> str:
s = rpfx(s, f"{pfx_src}{shift}.", f"{pfx_dst}Conv2d.")
s = rpfx(s, f"{pfx_src}{shift + 1}.", f"{pfx_dst}BatchNorm2d.")
s = rpfx(s, f"{pfx_src}{shift + 2}.", f"{pfx_dst}PReLU.")
return s
def cbg(s: str, pfx_src: str, pfx_dst: str) -> str:
s = rpfx(s, f"{pfx_src}0.", f"{pfx_dst}Conv2d.")
s = rpfx(s, f"{pfx_src}1.", f"{pfx_dst}BatchNorm2d.")
return s
v = rpfx(v, "shallow.0.", "ComputeShallow.Conv2d.")
v = cbr(v, "output1.", "Pyramid.Sum.Chain.CBR.")
v = cbr(v, "output2.", "Pyramid.Sum.PyramidL2.Sum.Chain.CBR.")
v = cbr(v, "output3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.Chain.CBR.")
v = cbr(v, "output4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.Chain.CBR.")
v = cbr(v, "output5.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.CBR.")
v = cbr(v, "conv1.", "Pyramid.CBR.")
v = cbr(v, "conv2.", "Pyramid.Sum.PyramidL2.CBR.")
v = cbr(v, "conv3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.CBR.")
v = cbr(v, "conv4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.CBR.")
v = mclm(v, "multifieldcrossatt.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.MCLM.")
v = mcrm(v, "dec_blk1.", "Pyramid.MCRM.")
v = mcrm(v, "dec_blk2.", "Pyramid.Sum.PyramidL2.MCRM.")
v = mcrm(v, "dec_blk3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.MCRM.")
v = mcrm(v, "dec_blk4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.MCRM.")
v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_1.")
v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_2.", shift=3)
v = rpfx(v, "insmask_head.6.", "RearrangeMultiView.Chain.Conv2d.")
v = cbg(v, "upsample1.", "ShallowUpscaler.Sum_2.Chain_1.CBG.")
v = cbg(v, "upsample2.", "ShallowUpscaler.CBG.")
v = rpfx(v, "output.0.", "Conv2d.")
if v != k:
keys_map[k] = v
for key, new_key in keys_map.items():
state_dict[new_key] = state_dict[key]
state_dict.pop(key)
return state_dict
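
To illustrate the mapping, a small sketch using dummy tensors (key names are drawn from the rules above; the shapes are irrelevant here):

import torch
from refiners.foundationals.swin.mvanet.converter import convert_weights

dummy = torch.zeros(1)
converted = convert_weights({
    "backbone.patch_embed.proj.weight": dummy,  # Swin backbone key
    "shallow.0.weight": dummy,                  # MVANet head key
    "sideout1.0.weight": dummy,                 # filtered out by rm_list
})
assert "SwinTransformer.PatchEmbedding.Conv2d.weight" in converted
assert "ComputeShallow.Conv2d.weight" in converted
assert not any("sideout" in key for key in converted)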


@ -0,0 +1,211 @@
# Multi-View Complementary Localization
import math
import torch
from torch import Tensor, device as Device
import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
from .utils import FeedForward, MultiheadAttention, MultiPool, PatchMerge, PatchwiseCrossAttention, Unflatten
class PerPixel(fl.Chain):
"""(B, C, H, W) -> H*W, B, C"""
def __init__(self):
super().__init__(
fl.Permute(2, 3, 0, 1),
fl.Flatten(0, 1),
)
class PositionEmbeddingSine(fl.Module):
"""
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
"""
def __init__(self, num_pos_feats: int, device: Device | None = None):
super().__init__()
self.device = device
temperature = 10000
self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32, device=self.device)
self.dim_t = temperature ** (2 * (self.dim_t // 2) / num_pos_feats)
def __call__(self, h: int, w: int) -> Tensor:
mask = torch.ones([1, h, w, 1], dtype=torch.bool, device=self.device)
y_embed = mask.cumsum(dim=1, dtype=torch.float32)
x_embed = mask.cumsum(dim=2, dtype=torch.float32)
eps, scale = 1e-6, 2 * math.pi
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * scale
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * scale
pos_x = x_embed / self.dim_t
pos_y = y_embed / self.dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
return torch.cat((pos_y, pos_x), dim=3).permute(1, 2, 0, 3).flatten(0, 1)
class MultiPoolPos(fl.Module):
def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine):
super().__init__()
self.pool_ratios = pool_ratios
self.positional_embedding = positional_embedding
def forward(self, *args: int) -> Tensor:
h, w = args
return torch.cat([self.positional_embedding(h // ratio, w // ratio) for ratio in self.pool_ratios])
class Repeat(fl.Module):
def __init__(self, dim: int = 0):
self.dim = dim
super().__init__()
def forward(self, x: Tensor, n: int) -> Tensor:
return torch.repeat_interleave(x, n, dim=self.dim)
class _MHA_Arg(fl.Sum):
def __init__(self, offset: int):
self.offset = offset
super().__init__(
fl.GetArg(offset), # value
fl.Chain(
fl.Parallel(
fl.GetArg(self.offset + 1), # position embedding
fl.Lambda(self._batch_size),
),
Repeat(1),
),
)
def _batch_size(self, *args: Tensor) -> int:
return args[self.offset].size(1)
class GlobalAttention(fl.Chain):
# Input must be a 4-tuple: (global, global pos. emb, pools, pools pos. emb.)
def __init__(
self,
emb_dim: int,
num_heads: int = 1,
device: Device | None = None,
):
super().__init__(
fl.Sum(
fl.GetArg(0), # global
fl.Chain(
fl.Parallel(
_MHA_Arg(0), # Q: global + pos. emb
_MHA_Arg(2), # K: pools + pos. emb
fl.GetArg(2), # V: pools
),
MultiheadAttention(emb_dim, num_heads, device=device),
),
),
)
class MCLM(fl.Chain):
"""Multi-View Complementary Localization Module
Inputs:
tensor: (b, 5, e, h, h)
Outputs:
tensor: (b, 5, e, h, h)
"""
def __init__(
self,
emb_dim: int,
num_heads: int = 1,
pool_ratios: list[int] | None = None,
device: Device | None = None,
):
if pool_ratios is None:
pool_ratios = [2, 8, 16]
positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device)
# LayerNorms in MCLM share their weights.
ln1 = fl.LayerNorm(emb_dim, device=device)
ln2 = fl.LayerNorm(emb_dim, device=device)
def proxy(m: fl.Module) -> fl.Module:
def f(x: Tensor) -> Tensor:
return m(x)
return fl.Lambda(f)
super().__init__(
fl.Parallel(
fl.Chain( # global
fl.Slicing(dim=1, start=4),
fl.Squeeze(1),
fl.Parallel(
PerPixel(), # glb
fl.Chain( # g_pos
fl.Lambda(lambda x: x.shape[-2:]), # type: ignore
positional_embedding,
),
),
),
fl.Chain( # local
fl.Slicing(dim=1, end=4),
fl.SetContext("mclm", "local"),
PatchMerge(),
fl.Parallel(
fl.Chain( # pool
MultiPool(pool_ratios),
fl.Squeeze(0),
),
fl.Chain( # pool_pos
fl.Lambda(lambda x: x.shape[-2:]), # type: ignore
MultiPoolPos(pool_ratios, positional_embedding),
),
),
),
),
fl.Lambda(lambda t1, t2: (*t1, *t2)), # type: ignore
GlobalAttention(emb_dim, num_heads, device=device),
ln1,
FeedForward(emb_dim, device=device),
ln2,
fl.SetContext("mclm", "global"),
fl.UseContext("mclm", "local"),
fl.Flatten(-2, -1),
fl.Permute(1, 3, 0, 2),
fl.Residual(
fl.Parallel(
fl.Identity(),
fl.Chain(
fl.UseContext("mclm", "global"),
Unflatten(0, (2, 8, 2, 8)), # 2, h/2, 2, h/2
fl.Permute(0, 2, 1, 3, 4, 5),
fl.Flatten(0, 1),
fl.Flatten(1, 2),
),
),
PatchwiseCrossAttention(emb_dim, num_heads, device=device),
),
proxy(ln1),
FeedForward(emb_dim, device=device),
proxy(ln2),
fl.Concatenate(
fl.Identity(),
fl.Chain(
fl.UseContext("mclm", "global"),
fl.Unsqueeze(0),
),
),
Unflatten(1, (16, 16)), # h, h
fl.Permute(3, 0, 4, 1, 2),
)
def init_context(self) -> Contexts:
return {"mclm": {"global": None, "local": None}}


@ -0,0 +1,119 @@
# Multi-View Complementary Refinement
import torch
from torch import Tensor, device as Device
import refiners.fluxion.layers as fl
from .utils import FeedForward, Interpolate, MultiPool, PatchMerge, PatchSplit, PatchwiseCrossAttention, Unflatten
class Multiply(fl.Chain):
def __init__(self, o1: fl.Module, o2: fl.Module) -> None:
super().__init__(o1, o2)
def forward(self, *args: Tensor) -> Tensor:
return torch.mul(self[0](*args), self[1](*args))
class TiledCrossAttention(fl.Chain):
def __init__(
self,
emb_dim: int,
dim: int,
num_heads: int = 1,
pool_ratios: list[int] | None = None,
device: Device | None = None,
):
# Input must be a 2-tuple: (local, global)
if pool_ratios is None:
pool_ratios = [1, 2, 4]
super().__init__(
fl.Distribute(
fl.Chain( # local
fl.Flatten(-2, -1),
fl.Permute(1, 3, 0, 2),
),
fl.Chain( # global
PatchSplit(),
fl.Squeeze(0),
MultiPool(pool_ratios),
),
),
fl.Sum(
fl.Chain(
fl.GetArg(0),
fl.Permute(2, 1, 0, 3),
),
fl.Chain(
PatchwiseCrossAttention(emb_dim, num_heads, device=device),
fl.Permute(2, 1, 0, 3),
),
),
fl.LayerNorm(emb_dim, device=device),
FeedForward(emb_dim, device=device),
fl.LayerNorm(emb_dim, device=device),
fl.Permute(0, 2, 3, 1),
Unflatten(-1, (dim, dim)),
)
class MCRM(fl.Chain):
"""Multi-View Complementary Refinement"""
def __init__(
self,
emb_dim: int,
size: int,
num_heads: int = 1,
pool_ratios: list[int] | None = None,
device: Device | None = None,
):
if pool_ratios is None:
pool_ratios = [1, 2, 4]
super().__init__(
fl.Parallel(
fl.Chain( # local
fl.Slicing(dim=1, end=4),
),
fl.Chain( # global
fl.Slicing(dim=1, start=4),
fl.Squeeze(1),
),
),
fl.Parallel(
Multiply(
fl.GetArg(0),
fl.Chain(
fl.GetArg(1),
fl.Conv2d(emb_dim, 1, 1, device=device),
fl.Sigmoid(),
Interpolate((size * 2, size * 2), "nearest"),
PatchSplit(),
),
),
fl.GetArg(1),
),
fl.Parallel(
TiledCrossAttention(emb_dim, size, num_heads, pool_ratios, device=device),
fl.GetArg(1),
),
fl.Concatenate(
fl.GetArg(0),
fl.Chain(
fl.Sum(
fl.GetArg(1),
fl.Chain(
fl.GetArg(0),
PatchMerge(),
Interpolate((size, size), "nearest"),
),
),
fl.Unsqueeze(1),
),
dim=1,
),
)
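
Similarly, a hedged shape sketch for MCRM at the resolution it receives as dec_blk4 in the pyramid defined later (batch size 1 assumed):

import torch
from refiners.foundationals.swin.mvanet.mcrm import MCRM

mcrm = MCRM(emb_dim=128, size=32)
views = torch.randn(1, 5, 128, 32, 32)  # 4 local tiles + 1 global view along dim 1
assert mcrm(views).shape == (1, 5, 128, 32, 32)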


@ -0,0 +1,337 @@
# Multi-View Aggregation Network (arXiv:2404.07445)
from torch import device as Device
import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
from refiners.foundationals.swin.swin_transformer import SwinTransformer
from .mclm import MCLM # Multi-View Complementary Localization
from .mcrm import MCRM # Multi-View Complementary Refinement
from .utils import BatchNorm2d, Interpolate, PatchMerge, PatchSplit, PReLU, Rescale, Unflatten
class CBG(fl.Chain):
"""(C)onvolution + (B)atchNorm + (G)eLU"""
def __init__(
self,
in_dim: int,
out_dim: int | None = None,
device: Device | None = None,
):
out_dim = out_dim or in_dim
super().__init__(
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
BatchNorm2d(out_dim, device=device),
fl.GeLU(),
)
class CBR(fl.Chain):
"""(C)onvolution + (B)atchNorm + Parametric (R)eLU"""
def __init__(
self,
in_dim: int,
out_dim: int | None = None,
device: Device | None = None,
):
out_dim = out_dim or in_dim
super().__init__(
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
BatchNorm2d(out_dim, device=device),
PReLU(device=device),
)
class SplitMultiView(fl.Chain):
"""
Split a high-resolution tensor into 5 half-resolution views (1 global view + 4 tiles)
See also the reverse Module [`RearrangeMultiView`][refiners.foundationals.swin.mvanet.RearrangeMultiView]
Inputs:
single_view (b, c, H, W)
Outputs:
multi_view (b, 5, c, H/2, W/2)
"""
def __init__(self):
super().__init__(
fl.Concatenate(
PatchSplit(), # global features
fl.Chain( # local features
Rescale(scale_factor=0.5, mode="bilinear"),
fl.Unsqueeze(1),
),
dim=1,
)
)
class ShallowUpscaler(fl.Chain):
"""4x Upscaler reusing the image as input to upscale the feature
See [[arXiv:2108.10257] SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/abs/2108.10257)
Args:
embedding_dim (int): the embedding dimension
Inputs:
feature (b, E, image_size/4, image_size/4)
Output:
upscaled tensor (b, E, image_size, image_size)
"""
def __init__(
self,
embedding_dim: int = 128,
device: Device | None = None,
):
super().__init__(
fl.Sum(
fl.Identity(),
fl.Chain(
fl.UseContext("mvanet", "shallow"),
Interpolate((256, 256)),
),
),
fl.Sum(
fl.Chain(
Rescale(2),
CBG(embedding_dim, device=device),
),
fl.Chain(
fl.UseContext("mvanet", "shallow"),
Interpolate((512, 512)),
),
),
Rescale(2),
CBG(embedding_dim, device=device),
)
class PyramidL5(fl.Chain):
def __init__(
self,
embedding_dim: int = 128,
device: Device | None = None,
):
super().__init__(
fl.GetArg(0), # output5
fl.Flatten(0, 1),
CBR(1024, embedding_dim, device=device),
Unflatten(0, (-1, 5)),
MCLM(embedding_dim, device=device),
fl.Flatten(0, 1),
Interpolate((32, 32)),
)
class PyramidL4(fl.Chain):
def __init__(
self,
embedding_dim: int = 128,
device: Device | None = None,
):
super().__init__(
fl.Sum(
PyramidL5(embedding_dim=embedding_dim, device=device),
fl.Chain(
fl.GetArg(1),
fl.Flatten(0, 1),
CBR(512, embedding_dim, device=device), # output4
Unflatten(0, (-1, 5)),
),
),
MCRM(embedding_dim, 32, device=device), # dec_blk4
fl.Flatten(0, 1),
CBR(embedding_dim, device=device), # conv4
Interpolate((64, 64)),
)
class PyramidL3(fl.Chain):
def __init__(
self,
embedding_dim: int = 128,
device: Device | None = None,
):
super().__init__(
fl.Sum(
PyramidL4(embedding_dim=embedding_dim, device=device),
fl.Chain(
fl.GetArg(2),
fl.Flatten(0, 1),
CBR(256, embedding_dim, device=device), # output3
Unflatten(0, (-1, 5)),
),
),
MCRM(embedding_dim, 64, device=device), # dec_blk3
fl.Flatten(0, 1),
CBR(embedding_dim, device=device), # conv3
Interpolate((128, 128)),
)
class PyramidL2(fl.Chain):
def __init__(
self,
embedding_dim: int = 128,
device: Device | None = None,
):
embedding_dim = 128
super().__init__(
fl.Sum(
PyramidL3(embedding_dim=embedding_dim, device=device),
fl.Chain(
fl.GetArg(3),
fl.Flatten(0, 1),
CBR(128, embedding_dim, device=device), # output2
Unflatten(0, (-1, 5)),
),
),
MCRM(embedding_dim, 128, device=device), # dec_blk2
fl.Flatten(0, 1),
CBR(embedding_dim, device=device), # conv2
Interpolate((128, 128)),
)
class Pyramid(fl.Chain):
"""
Recursive pyramidal network calling the MCLM and MCRM blocks.
It acts as an FPN (Feature Pyramid Network) neck for MVANet,
see [[arXiv:1612.03144] Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
Inputs:
features: a pyramid of N = 5 tensors
shapes are (b, 5, E_{0}, S_{0}, S_{0}), ..., (b, 5, E_{i}, S_{i}, S_{i}), ..., (b, 5, E_{N-1}, S_{N-1}, S_{N-1})
with S_{i} = S_{i-1} or S_{i} = 2*S_{i-1} for 0 < i < N
Outputs:
output (b, 5, E, S_{N-1}, S_{N-1})
"""
def __init__(
self,
embedding_dim: int = 128,
device: Device | None = None,
):
super().__init__(
fl.Sum(
PyramidL2(embedding_dim=embedding_dim, device=device),
fl.Chain(
fl.GetArg(4),
fl.Flatten(0, 1),
CBR(128, embedding_dim, device=device), # output1
Unflatten(0, (-1, 5)),
),
),
MCRM(embedding_dim, 128, device=device), # dec_blk1
fl.Flatten(0, 1),
CBR(embedding_dim, device=device), # conv1
Unflatten(0, (-1, 5)),
)
class RearrangeMultiView(fl.Chain):
"""
Fuse a multi-view tensor into a single-view tensor, using convolutions.
See also the reverse Module [`SplitMultiView`][refiners.foundationals.swin.mvanet.SplitMultiView]
Inputs:
multi_view (b, 5, E, H, W)
Outputs:
single_view (b, E, H*2, W*2)
"""
def __init__(
self,
embedding_dim: int = 128,
device: Device | None = None,
):
super().__init__(
fl.Sum(
fl.Chain( # local features
fl.Slicing(dim=1, end=4),
PatchMerge(),
),
fl.Chain( # global feature
fl.Slicing(dim=1, start=4),
fl.Squeeze(1),
Interpolate((256, 256)),
),
),
fl.Chain( # conv head
CBR(embedding_dim, 384, device=device),
CBR(384, device=device),
fl.Conv2d(384, embedding_dim, kernel_size=3, padding=1, device=device),
),
)
class ComputeShallow(fl.Passthrough):
def __init__(
self,
embedding_dim: int = 128,
device: Device | None = None,
):
super().__init__(
fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device),
fl.SetContext("mvanet", "shallow"),
)
class MVANet(fl.Chain):
"""Multi-view Aggregation Network for Dichotomous Image Segmentation
See [[arXiv:2404.07445] Multi-view Aggregation Network for Dichotomous Image Segmentation](https://arxiv.org/abs/2404.07445) for more details.
Args:
embedding_dim (int): embedding dimension
n_logits (int): the number of output logits (defaults to 1).
A single logit is used for alpha matting, foreground-background segmentation and salient object detection.
depths (list[int]): see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer]
num_heads (list[int]): see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer]
window_size (int): defaults to 12, see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer]
device (Device | None): the device to use
"""
def __init__(
self,
embedding_dim: int = 128,
n_logits: int = 1,
depths: list[int] | None = None,
num_heads: list[int] | None = None,
window_size: int = 12,
device: Device | None = None,
):
if depths is None:
depths = [2, 2, 18, 2]
if num_heads is None:
num_heads = [4, 8, 16, 32]
super().__init__(
ComputeShallow(embedding_dim=embedding_dim, device=device),
SplitMultiView(),
fl.Flatten(0, 1),
SwinTransformer(
embedding_dim=embedding_dim,
depths=depths,
num_heads=num_heads,
window_size=window_size,
device=device,
),
fl.Distribute(*(Unflatten(0, (-1, 5)) for _ in range(5))),
Pyramid(embedding_dim=embedding_dim, device=device),
RearrangeMultiView(embedding_dim=embedding_dim, device=device),
ShallowUpscaler(embedding_dim, device=device),
fl.Conv2d(embedding_dim, n_logits, kernel_size=3, padding=1, device=device),
)
def init_context(self) -> Contexts:
return {"mvanet": {"shallow": None}}


@ -0,0 +1,173 @@
import torch
from torch import Size, Tensor
from torch.nn.functional import (
adaptive_avg_pool2d,
interpolate, # type: ignore
)
import refiners.fluxion.layers as fl
class Unflatten(fl.Module):
def __init__(self, dim: int, sizes: tuple[int, ...]) -> None:
super().__init__()
self.dim = dim
self.sizes = Size(sizes)
def forward(self, x: Tensor) -> Tensor:
return torch.unflatten(input=x, dim=self.dim, sizes=self.sizes)
class Interpolate(fl.Module):
def __init__(self, size: tuple[int, ...], mode: str = "bilinear"):
super().__init__()
self.size = Size(size)
self.mode = mode
def forward(self, x: Tensor) -> Tensor:
return interpolate(x, size=self.size, mode=self.mode) # type: ignore
class Rescale(fl.Module):
def __init__(self, scale_factor: float, mode: str = "nearest"):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode
def forward(self, x: Tensor) -> Tensor:
return interpolate(x, scale_factor=self.scale_factor, mode=self.mode) # type: ignore
class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule):
def __init__(self, num_features: int, device: torch.device | None = None):
super().__init__(num_features=num_features, device=device) # type: ignore
class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation):
def __init__(self, device: torch.device | None = None):
super().__init__(device=device) # type: ignore
class PatchSplit(fl.Chain):
"""(B, N, H, W) -> B, 4, N, H/2, W/2"""
def __init__(self):
super().__init__(
Unflatten(-2, (2, -1)),
Unflatten(-1, (2, -1)),
fl.Permute(0, 2, 4, 1, 3, 5),
fl.Flatten(1, 2),
)
class PatchMerge(fl.Chain):
"""B, 4, N, H, W -> (B, N, 2*H, 2*W)"""
def __init__(self):
super().__init__(
Unflatten(1, (2, 2)),
fl.Permute(0, 3, 1, 4, 2, 5),
fl.Flatten(-2, -1),
fl.Flatten(-3, -2),
)
class FeedForward(fl.Residual):
def __init__(self, emb_dim: int, device: torch.device | None = None) -> None:
super().__init__(
fl.Linear(in_features=emb_dim, out_features=2 * emb_dim, device=device),
fl.ReLU(),
fl.Linear(in_features=2 * emb_dim, out_features=emb_dim, device=device),
)
class _GetArgs(fl.Parallel):
def __init__(self, n: int):
super().__init__(
fl.Chain(
fl.GetArg(0),
fl.Slicing(dim=0, start=n, end=n + 1),
fl.Squeeze(0),
),
fl.Chain(
fl.GetArg(1),
fl.Slicing(dim=0, start=n, end=n + 1),
fl.Squeeze(0),
),
fl.Chain(
fl.GetArg(1),
fl.Slicing(dim=0, start=n, end=n + 1),
fl.Squeeze(0),
),
)
class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule):
def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None):
super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore
@property
def weight(self) -> Tensor: # type: ignore
return self.in_proj_weight
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # type: ignore
return super().forward(q, k, v)[0]
class PatchwiseCrossAttention(fl.Chain):
# Input is 2 tensors of sizes (4, HW, B, C) and (4, HW', B, C),
# output is size (4, HW, B, C).
def __init__(
self,
d_model: int,
num_heads: int,
device: torch.device | None = None,
):
super().__init__(
fl.Concatenate(
fl.Chain(
_GetArgs(0),
MultiheadAttention(d_model, num_heads, device=device),
),
fl.Chain(
_GetArgs(1),
MultiheadAttention(d_model, num_heads, device=device),
),
fl.Chain(
_GetArgs(2),
MultiheadAttention(d_model, num_heads, device=device),
),
fl.Chain(
_GetArgs(3),
MultiheadAttention(d_model, num_heads, device=device),
),
),
Unflatten(0, (4, -1)),
)
class Pool(fl.Module):
def __init__(self, ratio: int) -> None:
super().__init__()
self.ratio = ratio
def forward(self, x: Tensor) -> Tensor:
b, _, h, w = x.shape
assert h % self.ratio == 0 and w % self.ratio == 0
r = adaptive_avg_pool2d(x, (h // self.ratio, w // self.ratio))
return torch.unflatten(r, 0, (b, -1))
class MultiPool(fl.Concatenate):
def __init__(self, pool_ratios: list[int]) -> None:
super().__init__(
*(
fl.Chain(
Pool(pool_ratio),
fl.Flatten(-2, -1),
fl.Permute(0, 3, 1, 2),
)
for pool_ratio in pool_ratios
),
dim=1,
)
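
For instance, a quick round-trip sketch of the tiling helpers (shapes follow the docstrings above; PatchMerge is the exact inverse of PatchSplit):

import torch
from refiners.foundationals.swin.mvanet.utils import PatchMerge, PatchSplit

x = torch.randn(2, 16, 64, 64)       # (B, N, H, W)
tiles = PatchSplit()(x)              # (B, 4, N, H/2, W/2)
assert tiles.shape == (2, 4, 16, 32, 32)
assert torch.equal(PatchMerge()(tiles), x)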


@ -0,0 +1,391 @@
# Swin Transformer (arXiv:2103.14030)
#
# Specific to MVANet, only supports square inputs.
# Originally adapted from the version in MVANet and InSPyReNet (https://github.com/plemeri/InSPyReNet)
# Original implementation by Microsoft at https://github.com/microsoft/Swin-Transformer
import functools
from math import isqrt
import torch
from torch import Tensor, device as Device
import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
def to_windows(x: Tensor, window_size: int) -> Tensor:
B, H, W, C = x.shape
assert W == H and H % window_size == 0
x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
return x.permute(0, 1, 3, 2, 4, 5).reshape(B, -1, window_size * window_size, C)
class ToWindows(fl.Module):
def __init__(self, window_size: int):
super().__init__()
self.window_size = window_size
def forward(self, x: Tensor) -> Tensor:
return to_windows(x, self.window_size)
class FromWindows(fl.Module):
def forward(self, x: Tensor) -> Tensor:
B, num_windows, window_size_2, C = x.shape
window_size = isqrt(window_size_2)
H = isqrt(num_windows * window_size_2)
x = x.reshape(B, H // window_size, H // window_size, window_size, window_size, C)
return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, H, C)
@functools.cache
def get_attn_mask(H: int, window_size: int, device: Device | None = None) -> Tensor:
assert H % window_size == 0
shift_size = window_size // 2
img_mask = torch.zeros((1, H, H, 1), device=device)
h_slices = (
slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None),
)
w_slices = (
slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = to_windows(img_mask, window_size).squeeze() # B, nW, window_size * window_size, [1]
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask.masked_fill_(attn_mask != 0, -100.0).masked_fill_(attn_mask == 0, 0.0)
return attn_mask
class Pad(fl.Module):
def __init__(self, step: int):
super().__init__()
self.step = step
def forward(self, x: Tensor) -> Tensor:
B, H, W, C = x.shape
assert W == H
if H % self.step == 0:
return x
p = self.step * ((H + self.step - 1) // self.step)
padded = torch.zeros(B, p, p, C, device=x.device, dtype=x.dtype)
padded[:, :H, :H, :] = x
return padded
class StatefulPad(fl.Chain):
def __init__(self, context: str, key: str, step: int) -> None:
super().__init__(
fl.SetContext(context=context, key=key, callback=self._push),
Pad(step=step),
)
def _push(self, sizes: list[int], x: Tensor) -> None:
sizes.append(x.size(1))
class StatefulUnpad(fl.Chain):
def __init__(self, context: str, key: str) -> None:
super().__init__(
fl.Parallel(
fl.Identity(),
fl.UseContext(context=context, key=key).compose(lambda x: x.pop()),
),
fl.Lambda(self._unpad),
)
@staticmethod
def _unpad(x: Tensor, size: int) -> Tensor:
return x[:, :size, :size, :]
class SquareUnflatten(fl.Module):
# ..., L^2, ... -> ..., L, L, ...
def __init__(self, dim: int = 0) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
d = isqrt(x.shape[self.dim])
return torch.unflatten(x, self.dim, (d, d))
class WindowUnflatten(fl.Module):
# ..., H, ... -> ..., H // ws, ws, ...
def __init__(self, window_size: int, dim: int = 0) -> None:
super().__init__()
self.window_size = window_size
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.dim] % self.window_size == 0
H = x.shape[self.dim]
return torch.unflatten(x, self.dim, (H // self.window_size, self.window_size))
class Roll(fl.Module):
def __init__(self, *shifts: tuple[int, int]):
super().__init__()
self.shifts = shifts
self._dims = tuple(s[0] for s in shifts)
self._shifts = tuple(s[1] for s in shifts)
def forward(self, x: Tensor) -> Tensor:
return torch.roll(x, self._shifts, self._dims)
class RelativePositionBias(fl.Module):
relative_position_index: Tensor
def __init__(self, window_size: int, num_heads: int, device: Device | None = None):
super().__init__()
self.relative_position_bias_table = torch.nn.Parameter(
torch.empty(
(2 * window_size - 1) * (2 * window_size - 1),
num_heads,
device=device,
)
)
relative_position_index = torch.empty(
window_size**2,
window_size**2,
device=device,
dtype=torch.int64,
)
self.register_buffer("relative_position_index", relative_position_index)
def forward(self) -> Tensor:
# Yes, this is a (trainable) constant.
return self.relative_position_bias_table[self.relative_position_index].permute(2, 0, 1).unsqueeze(0)
class WindowSDPA(fl.Module):
def __init__(
self,
dim: int,
window_size: int,
num_heads: int,
shift: bool = False,
device: Device | None = None,
):
super().__init__()
self.window_size = window_size
self.num_heads = num_heads
self.shift = shift
self.rpb = RelativePositionBias(window_size, num_heads, device=device)
def forward(self, x: Tensor):
B, num_windows, N, _C = x.shape
assert _C % (3 * self.num_heads) == 0
C = _C // 3
x = torch.reshape(x, (B * num_windows, N, 3, self.num_heads, C // self.num_heads))
q, k, v = x.permute(2, 0, 3, 1, 4)
attn_mask = self.rpb()
if self.shift:
mask = get_attn_mask(isqrt(num_windows * (self.window_size**2)), self.window_size, x.device)
mask = mask.reshape(1, num_windows, 1, N, N)
mask = mask.expand(B, -1, self.num_heads, -1, -1)
attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
x = x.transpose(1, 2).reshape(B, num_windows, N, C)
return x
class WindowAttention(fl.Chain):
"""
Window-based Multi-head Self-Attention (W-MSA), optionally shifted (SW-MSA).
It has a trainable relative position bias (RelativePositionBias).
The input projection is stored as a single Linear for q, k and v.
"""
def __init__(
self,
dim: int,
window_size: int,
num_heads: int,
shift: bool = False,
device: Device | None = None,
):
super().__init__(
fl.Linear(dim, dim * 3, bias=True, device=device),
WindowSDPA(dim, window_size, num_heads, shift, device=device),
fl.Linear(dim, dim, device=device),
)
class SwinTransformerBlock(fl.Chain):
def __init__(
self,
dim: int,
num_heads: int,
window_size: int = 7,
shift_size: int = 0,
mlp_ratio: float = 4.0,
device: Device | None = None,
):
assert 0 <= shift_size < window_size, "shift_size must be in [0, window_size)"
super().__init__(
fl.Residual(
fl.LayerNorm(dim, device=device),
SquareUnflatten(1),
StatefulPad(context="padding", key="sizes", step=window_size),
Roll((1, -shift_size), (2, -shift_size)),
ToWindows(window_size),
WindowAttention(
dim,
window_size=window_size,
num_heads=num_heads,
shift=shift_size > 0,
device=device,
),
FromWindows(),
Roll((1, shift_size), (2, shift_size)),
StatefulUnpad(context="padding", key="sizes"),
fl.Flatten(1, 2),
),
fl.Residual(
fl.LayerNorm(dim, device=device),
fl.Linear(dim, int(dim * mlp_ratio), device=device),
fl.GeLU(),
fl.Linear(int(dim * mlp_ratio), dim, device=device),
),
)
def init_context(self) -> Contexts:
return {"padding": {"sizes": []}}
class PatchMerging(fl.Chain):
def __init__(self, dim: int, device: Device | None = None):
super().__init__(
SquareUnflatten(1),
Pad(2),
WindowUnflatten(2, 2),
WindowUnflatten(2, 1),
fl.Permute(0, 1, 3, 4, 2, 5),
fl.Flatten(3),
fl.Flatten(1, 2),
fl.LayerNorm(4 * dim, device=device),
fl.Linear(4 * dim, 2 * dim, bias=False, device=device),
)
class BasicLayer(fl.Chain):
def __init__(
self,
dim: int,
depth: int,
num_heads: int,
window_size: int = 7,
mlp_ratio: float = 4.0,
device: Device | None = None,
):
super().__init__(
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
device=device,
)
for i in range(depth)
)
class PatchEmbedding(fl.Chain):
def __init__(
self,
patch_size: tuple[int, int] = (4, 4),
in_chans: int = 3,
embedding_dim: int = 96,
device: Device | None = None,
):
super().__init__(
fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device),
fl.Flatten(2),
fl.Transpose(1, 2),
fl.LayerNorm(embedding_dim, device=device),
)
class SwinTransformer(fl.Chain):
"""Swin Transformer (arXiv:2103.14030)
Currently specific to MVANet, only supports square inputs.
"""
def __init__(
self,
patch_size: tuple[int, int] = (4, 4),
in_chans: int = 3,
embedding_dim: int = 96,
depths: list[int] | None = None,
num_heads: list[int] | None = None,
window_size: int = 7, # image size is 32 * this
mlp_ratio: float = 4.0,
device: Device | None = None,
):
if depths is None:
depths = [2, 2, 6, 2]
if num_heads is None:
num_heads = [3, 6, 12, 24]
self.num_layers = len(depths)
assert len(num_heads) == self.num_layers
super().__init__(
PatchEmbedding(
patch_size=patch_size,
in_chans=in_chans,
embedding_dim=embedding_dim,
device=device,
),
fl.Passthrough(
fl.Transpose(1, 2),
SquareUnflatten(2),
fl.SetContext("swin", "outputs", callback=lambda t, x: t.append(x)),
),
*(
fl.Chain(
BasicLayer(
dim=int(embedding_dim * 2**i),
depth=depths[i],
num_heads=num_heads[i],
window_size=window_size,
mlp_ratio=mlp_ratio,
device=device,
),
fl.Passthrough(
fl.LayerNorm(int(embedding_dim * 2**i), device=device),
fl.Transpose(1, 2),
SquareUnflatten(2),
fl.SetContext("swin", "outputs", callback=lambda t, x: t.insert(0, x)),
),
PatchMerging(dim=int(embedding_dim * 2**i), device=device)
if i < self.num_layers - 1
else fl.UseContext("swin", "outputs").compose(lambda t: tuple(t)),
)
for i in range(self.num_layers)
),
)
def init_context(self) -> Contexts:
return {"swin": {"outputs": []}}

tests/e2e/test_mvanet.py (new file, 59 lines)

@ -0,0 +1,59 @@
from pathlib import Path
from warnings import warn
import pytest
import torch
from PIL import Image
from tests.utils import ensure_similar_images
from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image
from refiners.foundationals.swin.mvanet import MVANet
def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore
@pytest.fixture(scope="module")
def ref_path(test_e2e_path: Path) -> Path:
return test_e2e_path / "test_mvanet_ref"
@pytest.fixture(scope="module")
def ref_cactus(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "cactus.png").convert("RGB")
@pytest.fixture
def expected_cactus_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cactus_mask.png")
@pytest.fixture(scope="module")
def mvanet_weights(test_weights_path: Path) -> Path:
weights = test_weights_path / "mvanet" / "mvanet.safetensors"
if not weights.is_file():
warn(f"could not find weights at {test_weights_path}, skipping")
pytest.skip(allow_module_level=True)
return weights
@pytest.fixture
def mvanet_model(mvanet_weights: Path, test_device: torch.device) -> MVANet:
model = MVANet(device=test_device).eval() # .eval() is important!
model.load_from_safetensors(mvanet_weights)
return model
@no_grad()
def test_mvanet(
mvanet_model: MVANet,
ref_cactus: Image.Image,
expected_cactus_mask: Image.Image,
test_device: torch.device,
):
in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
prediction: torch.Tensor = mvanet_model(in_t.to(test_device)).sigmoid()
cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))


@ -0,0 +1,3 @@
`cactus.png` is cropped from this image: https://www.freepik.com/free-photo/laptop-notebook-pen-coffee-cup-plants-wooden-desk_269339828.htm
`expected_cactus_mask.png` has been generated using the [official MVANet codebase](https://github.com/qianyu-dlut/MVANet) and weights.

Two binary image files added (96 KiB and 7.8 KiB, not shown): presumably cactus.png and expected_cactus_mask.png, the test references described above.


@ -0,0 +1 @@
# Minimal type stub for gdown (only the function used by the weight-download script).
def download(id: str, output: str, quiet: bool = False) -> str: ...