mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
Add Multi-View Aggregation Network (MVANet)
Co-authored-by: Pierre Colle <piercus@gmail.com>
This commit is contained in:
parent
58c1cc7cd4
commit
10dfa73a09
|
@ -9,4 +9,4 @@
|
|||
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> DINOv2](foundationals/dinov2.md)
|
||||
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Latent Diffusion](foundationals/latent_diffusion.md)
|
||||
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Segment Anything](foundationals/segment_anything.md)
|
||||
|
||||
* [<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Swin Transformers](foundationals/swin.md)
|
||||
|
|
2
docs/reference/foundationals/swin.md
Normal file
2
docs/reference/foundationals/swin.md
Normal file
|
@ -0,0 +1,2 @@
|
|||
::: refiners.foundationals.swin.swin_transformer
|
||||
::: refiners.foundationals.swin.mvanet
|
|
@ -58,6 +58,7 @@ conversion = [
|
|||
"segment-anything-py>=1.0",
|
||||
"requests>=2.26.0",
|
||||
"tqdm>=4.62.3",
|
||||
"gdown>=5.2.0",
|
||||
]
|
||||
doc = [
|
||||
# required by mkdocs to format the signatures
|
||||
|
|
|
@ -31,6 +31,8 @@ babel==2.15.0
|
|||
# via mkdocs-material
|
||||
backports-strenum==1.3.1
|
||||
# via griffe
|
||||
beautifulsoup4==4.12.3
|
||||
# via gdown
|
||||
bitsandbytes==0.43.3
|
||||
# via refiners
|
||||
black==24.4.2
|
||||
|
@ -70,6 +72,7 @@ docker-pycreds==0.4.0
|
|||
filelock==3.15.4
|
||||
# via datasets
|
||||
# via diffusers
|
||||
# via gdown
|
||||
# via huggingface-hub
|
||||
# via torch
|
||||
# via transformers
|
||||
|
@ -85,6 +88,8 @@ fsspec==2024.5.0
|
|||
# via torch
|
||||
future==1.0.0
|
||||
# via neptune
|
||||
gdown==5.2.0
|
||||
# via refiners
|
||||
ghp-import==2.1.0
|
||||
# via mkdocs
|
||||
gitdb==4.0.11
|
||||
|
@ -274,6 +279,8 @@ pyjwt==2.9.0
|
|||
pymdown-extensions==10.9
|
||||
# via mkdocs-material
|
||||
# via mkdocstrings
|
||||
pysocks==1.7.1
|
||||
# via requests
|
||||
python-dateutil==2.9.0.post0
|
||||
# via arrow
|
||||
# via botocore
|
||||
|
@ -311,6 +318,7 @@ requests==2.32.3
|
|||
# via bravado-core
|
||||
# via datasets
|
||||
# via diffusers
|
||||
# via gdown
|
||||
# via huggingface-hub
|
||||
# via mkdocs-material
|
||||
# via neptune
|
||||
|
@ -356,6 +364,8 @@ six==1.16.0
|
|||
# via rfc3339-validator
|
||||
smmap==5.0.1
|
||||
# via gitdb
|
||||
soupsieve==2.6
|
||||
# via beautifulsoup4
|
||||
swagger-spec-validator==3.0.4
|
||||
# via bravado-core
|
||||
# via neptune
|
||||
|
@ -383,6 +393,7 @@ torchvision==0.19.0
|
|||
# via timm
|
||||
tqdm==4.66.4
|
||||
# via datasets
|
||||
# via gdown
|
||||
# via huggingface-hub
|
||||
# via refiners
|
||||
# via transformers
|
||||
|
|
40
scripts/conversion/convert_mvanet.py
Normal file
40
scripts/conversion/convert_mvanet.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from refiners.fluxion.utils import load_tensors, save_to_safetensors
|
||||
from refiners.foundationals.swin.mvanet.converter import convert_weights
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--from",
|
||||
type=str,
|
||||
required=True,
|
||||
dest="source_path",
|
||||
help="A MVANet checkpoint. One can be found at https://github.com/qianyu-dlut/MVANet",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to",
|
||||
type=str,
|
||||
dest="output_path",
|
||||
default=None,
|
||||
help=(
|
||||
"Path to save the converted model. If not specified, the output path will be the source path with the"
|
||||
" extension changed to .safetensors."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--half", action="store_true", dest="half")
|
||||
args = parser.parse_args()
|
||||
|
||||
src_weights = load_tensors(args.source_path)
|
||||
weights = convert_weights(src_weights)
|
||||
if args.half:
|
||||
weights = {key: value.half() for key, value in weights.items()}
|
||||
if args.output_path is None:
|
||||
args.output_path = f"{Path(args.source_path).stem}.safetensors"
|
||||
save_to_safetensors(path=args.output_path, tensors=weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -11,6 +11,7 @@ import subprocess
|
|||
import sys
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import gdown
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -446,6 +447,25 @@ def download_ic_light():
|
|||
)
|
||||
|
||||
|
||||
def download_mvanet():
|
||||
fn = "Model_80.pth"
|
||||
dest_folder = os.path.join(test_weights_dir, "mvanet")
|
||||
dest_filename = os.path.join(dest_folder, fn)
|
||||
|
||||
if os.environ.get("DRY_RUN") == "1":
|
||||
return
|
||||
|
||||
if os.path.exists(dest_filename):
|
||||
print(f"✖️ ️ Skipping previously downloaded mvanet/{fn}")
|
||||
else:
|
||||
os.makedirs(dest_folder, exist_ok=True)
|
||||
print(f"🔽 Downloading mvanet/{fn} => '{rel(dest_filename)}'", end="\n")
|
||||
gdown.download(id="1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv", output=dest_filename, quiet=True)
|
||||
print(f"{previous_line}✅ Downloaded mvanet/{fn} => '{rel(dest_filename)}' ")
|
||||
|
||||
check_hash(dest_filename, "b915d492")
|
||||
|
||||
|
||||
def printg(msg: str):
|
||||
"""print in green color"""
|
||||
print("\033[92m" + msg + "\033[0m")
|
||||
|
@ -808,6 +828,16 @@ def convert_ic_light():
|
|||
)
|
||||
|
||||
|
||||
def convert_mvanet():
|
||||
run_conversion_script(
|
||||
"convert_mvanet.py",
|
||||
"tests/weights/mvanet/Model_80.pth",
|
||||
"tests/weights/mvanet/mvanet.safetensors",
|
||||
half=True,
|
||||
expected_hash="bf9ae4cb",
|
||||
)
|
||||
|
||||
|
||||
def download_all():
|
||||
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
|
||||
download_sd15("runwayml/stable-diffusion-v1-5")
|
||||
|
@ -830,6 +860,7 @@ def download_all():
|
|||
download_sdxl_lightning_base()
|
||||
download_sdxl_lightning_lora()
|
||||
download_ic_light()
|
||||
download_mvanet()
|
||||
|
||||
|
||||
def convert_all():
|
||||
|
@ -850,6 +881,7 @@ def convert_all():
|
|||
convert_lcm_base()
|
||||
convert_sdxl_lightning_base()
|
||||
convert_ic_light()
|
||||
convert_mvanet()
|
||||
|
||||
|
||||
def main():
|
||||
|
|
3
src/refiners/foundationals/swin/__init__.py
Normal file
3
src/refiners/foundationals/swin/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .swin_transformer import SwinTransformer
|
||||
|
||||
__all__ = ["SwinTransformer"]
|
3
src/refiners/foundationals/swin/mvanet/__init__.py
Normal file
3
src/refiners/foundationals/swin/mvanet/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .mvanet import MVANet
|
||||
|
||||
__all__ = ["MVANet"]
|
138
src/refiners/foundationals/swin/mvanet/converter.py
Normal file
138
src/refiners/foundationals/swin/mvanet/converter.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
import re
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def convert_weights(official_state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
rm_list = [
|
||||
# Official weights contains useless keys
|
||||
# See https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425
|
||||
r"multifieldcrossatt.linear[56]",
|
||||
r"multifieldcrossatt.attention.5",
|
||||
r"dec_blk\d+\.linear[12]",
|
||||
r"dec_blk[1234]\.attention\.[4567]",
|
||||
# We don't need the sideout weights
|
||||
r"sideout\d+",
|
||||
]
|
||||
state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)}
|
||||
|
||||
keys_map: dict[str, str] = {}
|
||||
for k in state_dict.keys():
|
||||
v: str = k
|
||||
|
||||
def rpfx(s: str, src: str, dst: str) -> str:
|
||||
if not s.startswith(src):
|
||||
return s
|
||||
return s.replace(src, dst, 1)
|
||||
|
||||
# Swin Transformer backbone
|
||||
|
||||
v = rpfx(v, "backbone.patch_embed.proj.", "SwinTransformer.PatchEmbedding.Conv2d.")
|
||||
v = rpfx(v, "backbone.patch_embed.norm.", "SwinTransformer.PatchEmbedding.LayerNorm.")
|
||||
|
||||
if m := re.match(r"backbone\.layers\.(\d+)\.downsample\.(.*)", v):
|
||||
s = m.group(2).replace("reduction.", "Linear.").replace("norm.", "LayerNorm.")
|
||||
v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.PatchMerging.{s}"
|
||||
|
||||
if m := re.match(r"backbone\.layers\.(\d+)\.blocks\.(\d+)\.(.*)", v):
|
||||
s = m.group(3)
|
||||
s = s.replace("norm1.", "Residual_1.LayerNorm.")
|
||||
s = s.replace("norm2.", "Residual_2.LayerNorm.")
|
||||
|
||||
s = s.replace("attn.qkv.", "Residual_1.WindowAttention.Linear_1.")
|
||||
s = s.replace("attn.proj.", "Residual_1.WindowAttention.Linear_2.")
|
||||
s = s.replace("attn.relative_position", "Residual_1.WindowAttention.WindowSDPA.rpb.relative_position")
|
||||
|
||||
s = s.replace("mlp.fc", "Residual_2.Linear_")
|
||||
v = ".".join(
|
||||
[
|
||||
f"SwinTransformer.Chain_{int(m.group(1)) + 1}",
|
||||
f"BasicLayer.SwinTransformerBlock_{int(m.group(2)) + 1}",
|
||||
s,
|
||||
]
|
||||
)
|
||||
|
||||
if m := re.match(r"backbone\.norm(\d+)\.(.*)", v):
|
||||
v = f"SwinTransformer.Chain_{int(m.group(1)) + 1}.Passthrough.LayerNorm.{m.group(2)}"
|
||||
|
||||
# MVANet
|
||||
|
||||
def mclm(s: str, pfx_src: str, pfx_dst: str) -> str:
|
||||
pca = f"{pfx_dst}Residual.PatchwiseCrossAttention"
|
||||
s = rpfx(s, f"{pfx_src}linear1.", f"{pfx_dst}FeedForward_1.Linear_1.")
|
||||
s = rpfx(s, f"{pfx_src}linear2.", f"{pfx_dst}FeedForward_1.Linear_2.")
|
||||
s = rpfx(s, f"{pfx_src}linear3.", f"{pfx_dst}FeedForward_2.Linear_1.")
|
||||
s = rpfx(s, f"{pfx_src}linear4.", f"{pfx_dst}FeedForward_2.Linear_2.")
|
||||
s = rpfx(s, f"{pfx_src}norm1.", f"{pfx_dst}LayerNorm_1.")
|
||||
s = rpfx(s, f"{pfx_src}norm2.", f"{pfx_dst}LayerNorm_2.")
|
||||
s = rpfx(s, f"{pfx_src}attention.0.", f"{pfx_dst}GlobalAttention.Sum.Chain.MultiheadAttention.")
|
||||
s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.")
|
||||
s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.")
|
||||
s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.")
|
||||
s = rpfx(s, f"{pfx_src}attention.4.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.")
|
||||
return s
|
||||
|
||||
def mcrm(s: str, pfx_src: str, pfx_dst: str) -> str:
|
||||
# Note: there are no linear{1,2}, see https://github.com/qianyu-dlut/MVANet/issues/3#issuecomment-2105650425
|
||||
tca = f"{pfx_dst}Parallel_3.TiledCrossAttention"
|
||||
pca = f"{tca}.Sum.Chain_2.PatchwiseCrossAttention"
|
||||
s = rpfx(s, f"{pfx_src}linear3.", f"{tca}.FeedForward.Linear_1.")
|
||||
s = rpfx(s, f"{pfx_src}linear4.", f"{tca}.FeedForward.Linear_2.")
|
||||
s = rpfx(s, f"{pfx_src}norm1.", f"{tca}.LayerNorm_1.")
|
||||
s = rpfx(s, f"{pfx_src}norm2.", f"{tca}.LayerNorm_2.")
|
||||
s = rpfx(s, f"{pfx_src}attention.0.", f"{pca}.Concatenate.Chain_1.MultiheadAttention.")
|
||||
s = rpfx(s, f"{pfx_src}attention.1.", f"{pca}.Concatenate.Chain_2.MultiheadAttention.")
|
||||
s = rpfx(s, f"{pfx_src}attention.2.", f"{pca}.Concatenate.Chain_3.MultiheadAttention.")
|
||||
s = rpfx(s, f"{pfx_src}attention.3.", f"{pca}.Concatenate.Chain_4.MultiheadAttention.")
|
||||
s = rpfx(s, f"{pfx_src}sal_conv.", f"{pfx_dst}Parallel_2.Multiply.Chain.Conv2d.")
|
||||
return s
|
||||
|
||||
def cbr(s: str, pfx_src: str, pfx_dst: str, shift: int = 0) -> str:
|
||||
s = rpfx(s, f"{pfx_src}{shift}.", f"{pfx_dst}Conv2d.")
|
||||
s = rpfx(s, f"{pfx_src}{shift + 1}.", f"{pfx_dst}BatchNorm2d.")
|
||||
s = rpfx(s, f"{pfx_src}{shift + 2}.", f"{pfx_dst}PReLU.")
|
||||
return s
|
||||
|
||||
def cbg(s: str, pfx_src: str, pfx_dst: str) -> str:
|
||||
s = rpfx(s, f"{pfx_src}0.", f"{pfx_dst}Conv2d.")
|
||||
s = rpfx(s, f"{pfx_src}1.", f"{pfx_dst}BatchNorm2d.")
|
||||
return s
|
||||
|
||||
v = rpfx(v, "shallow.0.", "ComputeShallow.Conv2d.")
|
||||
|
||||
v = cbr(v, "output1.", "Pyramid.Sum.Chain.CBR.")
|
||||
v = cbr(v, "output2.", "Pyramid.Sum.PyramidL2.Sum.Chain.CBR.")
|
||||
v = cbr(v, "output3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.Chain.CBR.")
|
||||
v = cbr(v, "output4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.Chain.CBR.")
|
||||
v = cbr(v, "output5.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.CBR.")
|
||||
|
||||
v = cbr(v, "conv1.", "Pyramid.CBR.")
|
||||
v = cbr(v, "conv2.", "Pyramid.Sum.PyramidL2.CBR.")
|
||||
v = cbr(v, "conv3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.CBR.")
|
||||
v = cbr(v, "conv4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.CBR.")
|
||||
|
||||
v = mclm(v, "multifieldcrossatt.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.Sum.PyramidL5.MCLM.")
|
||||
|
||||
v = mcrm(v, "dec_blk1.", "Pyramid.MCRM.")
|
||||
v = mcrm(v, "dec_blk2.", "Pyramid.Sum.PyramidL2.MCRM.")
|
||||
v = mcrm(v, "dec_blk3.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.MCRM.")
|
||||
v = mcrm(v, "dec_blk4.", "Pyramid.Sum.PyramidL2.Sum.PyramidL3.Sum.PyramidL4.MCRM.")
|
||||
|
||||
v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_1.")
|
||||
v = cbr(v, "insmask_head.", "RearrangeMultiView.Chain.CBR_2.", shift=3)
|
||||
|
||||
v = rpfx(v, "insmask_head.6.", "RearrangeMultiView.Chain.Conv2d.")
|
||||
|
||||
v = cbg(v, "upsample1.", "ShallowUpscaler.Sum_2.Chain_1.CBG.")
|
||||
v = cbg(v, "upsample2.", "ShallowUpscaler.CBG.")
|
||||
|
||||
v = rpfx(v, "output.0.", "Conv2d.")
|
||||
|
||||
if v != k:
|
||||
keys_map[k] = v
|
||||
|
||||
for key, new_key in keys_map.items():
|
||||
state_dict[new_key] = state_dict[key]
|
||||
state_dict.pop(key)
|
||||
|
||||
return state_dict
|
211
src/refiners/foundationals/swin/mvanet/mclm.py
Normal file
211
src/refiners/foundationals/swin/mvanet/mclm.py
Normal file
|
@ -0,0 +1,211 @@
|
|||
# Multi-View Complementary Localization
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device as Device
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.context import Contexts
|
||||
|
||||
from .utils import FeedForward, MultiheadAttention, MultiPool, PatchMerge, PatchwiseCrossAttention, Unflatten
|
||||
|
||||
|
||||
class PerPixel(fl.Chain):
|
||||
"""(B, C, H, W) -> H*W, B, C"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
fl.Permute(2, 3, 0, 1),
|
||||
fl.Flatten(0, 1),
|
||||
)
|
||||
|
||||
|
||||
class PositionEmbeddingSine(fl.Module):
|
||||
"""
|
||||
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats: int, device: Device | None = None):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
temperature = 10000
|
||||
self.dim_t = torch.arange(0, num_pos_feats, dtype=torch.float32, device=self.device)
|
||||
self.dim_t = temperature ** (2 * (self.dim_t // 2) / num_pos_feats)
|
||||
|
||||
def __call__(self, h: int, w: int) -> Tensor:
|
||||
mask = torch.ones([1, h, w, 1], dtype=torch.bool, device=self.device)
|
||||
y_embed = mask.cumsum(dim=1, dtype=torch.float32)
|
||||
x_embed = mask.cumsum(dim=2, dtype=torch.float32)
|
||||
|
||||
eps, scale = 1e-6, 2 * math.pi
|
||||
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * scale
|
||||
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * scale
|
||||
|
||||
pos_x = x_embed / self.dim_t
|
||||
pos_y = y_embed / self.dim_t
|
||||
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
return torch.cat((pos_y, pos_x), dim=3).permute(1, 2, 0, 3).flatten(0, 1)
|
||||
|
||||
|
||||
class MultiPoolPos(fl.Module):
|
||||
def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine):
|
||||
super().__init__()
|
||||
self.pool_ratios = pool_ratios
|
||||
self.positional_embedding = positional_embedding
|
||||
|
||||
def forward(self, *args: int) -> Tensor:
|
||||
h, w = args
|
||||
return torch.cat([self.positional_embedding(h // ratio, w // ratio) for ratio in self.pool_ratios])
|
||||
|
||||
|
||||
class Repeat(fl.Module):
|
||||
def __init__(self, dim: int = 0):
|
||||
self.dim = dim
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: Tensor, n: int) -> Tensor:
|
||||
return torch.repeat_interleave(x, n, dim=self.dim)
|
||||
|
||||
|
||||
class _MHA_Arg(fl.Sum):
|
||||
def __init__(self, offset: int):
|
||||
self.offset = offset
|
||||
super().__init__(
|
||||
fl.GetArg(offset), # value
|
||||
fl.Chain(
|
||||
fl.Parallel(
|
||||
fl.GetArg(self.offset + 1), # position embedding
|
||||
fl.Lambda(self._batch_size),
|
||||
),
|
||||
Repeat(1),
|
||||
),
|
||||
)
|
||||
|
||||
def _batch_size(self, *args: Tensor) -> int:
|
||||
return args[self.offset].size(1)
|
||||
|
||||
|
||||
class GlobalAttention(fl.Chain):
|
||||
# Input must be a 4-tuple: (global, global pos. emb, pools, pools pos. emb.)
|
||||
def __init__(
|
||||
self,
|
||||
emb_dim: int,
|
||||
num_heads: int = 1,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
fl.GetArg(0), # global
|
||||
fl.Chain(
|
||||
fl.Parallel(
|
||||
_MHA_Arg(0), # Q: global + pos. emb
|
||||
_MHA_Arg(2), # K: pools + pos. emb
|
||||
fl.GetArg(2), # V: pools
|
||||
),
|
||||
MultiheadAttention(emb_dim, num_heads, device=device),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MCLM(fl.Chain):
|
||||
"""Multi-View Complementary Localization Module
|
||||
Inputs:
|
||||
tensor: (b, 5, e, h, h)
|
||||
Outputs:
|
||||
tensor: (b, 5, e, h, h)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
emb_dim: int,
|
||||
num_heads: int = 1,
|
||||
pool_ratios: list[int] | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
if pool_ratios is None:
|
||||
pool_ratios = [2, 8, 16]
|
||||
|
||||
positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device)
|
||||
|
||||
# LayerNorms in MCLM share their weights.
|
||||
|
||||
ln1 = fl.LayerNorm(emb_dim, device=device)
|
||||
ln2 = fl.LayerNorm(emb_dim, device=device)
|
||||
|
||||
def proxy(m: fl.Module) -> fl.Module:
|
||||
def f(x: Tensor) -> Tensor:
|
||||
return m(x)
|
||||
|
||||
return fl.Lambda(f)
|
||||
|
||||
super().__init__(
|
||||
fl.Parallel(
|
||||
fl.Chain( # global
|
||||
fl.Slicing(dim=1, start=4),
|
||||
fl.Squeeze(1),
|
||||
fl.Parallel(
|
||||
PerPixel(), # glb
|
||||
fl.Chain( # g_pos
|
||||
fl.Lambda(lambda x: x.shape[-2:]), # type: ignore
|
||||
positional_embedding,
|
||||
),
|
||||
),
|
||||
),
|
||||
fl.Chain( # local
|
||||
fl.Slicing(dim=1, end=4),
|
||||
fl.SetContext("mclm", "local"),
|
||||
PatchMerge(),
|
||||
fl.Parallel(
|
||||
fl.Chain( # pool
|
||||
MultiPool(pool_ratios),
|
||||
fl.Squeeze(0),
|
||||
),
|
||||
fl.Chain( # pool_pos
|
||||
fl.Lambda(lambda x: x.shape[-2:]), # type: ignore
|
||||
MultiPoolPos(pool_ratios, positional_embedding),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
fl.Lambda(lambda t1, t2: (*t1, *t2)), # type: ignore
|
||||
GlobalAttention(emb_dim, num_heads, device=device),
|
||||
ln1,
|
||||
FeedForward(emb_dim, device=device),
|
||||
ln2,
|
||||
fl.SetContext("mclm", "global"),
|
||||
fl.UseContext("mclm", "local"),
|
||||
fl.Flatten(-2, -1),
|
||||
fl.Permute(1, 3, 0, 2),
|
||||
fl.Residual(
|
||||
fl.Parallel(
|
||||
fl.Identity(),
|
||||
fl.Chain(
|
||||
fl.UseContext("mclm", "global"),
|
||||
Unflatten(0, (2, 8, 2, 8)), # 2, h/2, 2, h/2
|
||||
fl.Permute(0, 2, 1, 3, 4, 5),
|
||||
fl.Flatten(0, 1),
|
||||
fl.Flatten(1, 2),
|
||||
),
|
||||
),
|
||||
PatchwiseCrossAttention(emb_dim, num_heads, device=device),
|
||||
),
|
||||
proxy(ln1),
|
||||
FeedForward(emb_dim, device=device),
|
||||
proxy(ln2),
|
||||
fl.Concatenate(
|
||||
fl.Identity(),
|
||||
fl.Chain(
|
||||
fl.UseContext("mclm", "global"),
|
||||
fl.Unsqueeze(0),
|
||||
),
|
||||
),
|
||||
Unflatten(1, (16, 16)), # h, h
|
||||
fl.Permute(3, 0, 4, 1, 2),
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"mclm": {"global": None, "local": None}}
|
119
src/refiners/foundationals/swin/mvanet/mcrm.py
Normal file
119
src/refiners/foundationals/swin/mvanet/mcrm.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
# Multi-View Complementary Refinement
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device as Device
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
from .utils import FeedForward, Interpolate, MultiPool, PatchMerge, PatchSplit, PatchwiseCrossAttention, Unflatten
|
||||
|
||||
|
||||
class Multiply(fl.Chain):
|
||||
def __init__(self, o1: fl.Module, o2: fl.Module) -> None:
|
||||
super().__init__(o1, o2)
|
||||
|
||||
def forward(self, *args: Tensor) -> Tensor:
|
||||
return torch.mul(self[0](*args), self[1](*args))
|
||||
|
||||
|
||||
class TiledCrossAttention(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
emb_dim: int,
|
||||
dim: int,
|
||||
num_heads: int = 1,
|
||||
pool_ratios: list[int] | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
# Input must be a 4-tuple: (local, global)
|
||||
|
||||
if pool_ratios is None:
|
||||
pool_ratios = [1, 2, 4]
|
||||
|
||||
super().__init__(
|
||||
fl.Distribute(
|
||||
fl.Chain( # local
|
||||
fl.Flatten(-2, -1),
|
||||
fl.Permute(1, 3, 0, 2),
|
||||
),
|
||||
fl.Chain( # global
|
||||
PatchSplit(),
|
||||
fl.Squeeze(0),
|
||||
MultiPool(pool_ratios),
|
||||
),
|
||||
),
|
||||
fl.Sum(
|
||||
fl.Chain(
|
||||
fl.GetArg(0),
|
||||
fl.Permute(2, 1, 0, 3),
|
||||
),
|
||||
fl.Chain(
|
||||
PatchwiseCrossAttention(emb_dim, num_heads, device=device),
|
||||
fl.Permute(2, 1, 0, 3),
|
||||
),
|
||||
),
|
||||
fl.LayerNorm(emb_dim, device=device),
|
||||
FeedForward(emb_dim, device=device),
|
||||
fl.LayerNorm(emb_dim, device=device),
|
||||
fl.Permute(0, 2, 3, 1),
|
||||
Unflatten(-1, (dim, dim)),
|
||||
)
|
||||
|
||||
|
||||
class MCRM(fl.Chain):
|
||||
"""Multi-View Complementary Refinement"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
emb_dim: int,
|
||||
size: int,
|
||||
num_heads: int = 1,
|
||||
pool_ratios: list[int] | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
if pool_ratios is None:
|
||||
pool_ratios = [1, 2, 4]
|
||||
|
||||
super().__init__(
|
||||
fl.Parallel(
|
||||
fl.Chain( # local
|
||||
fl.Slicing(dim=1, end=4),
|
||||
),
|
||||
fl.Chain( # global
|
||||
fl.Slicing(dim=1, start=4),
|
||||
fl.Squeeze(1),
|
||||
),
|
||||
),
|
||||
fl.Parallel(
|
||||
Multiply(
|
||||
fl.GetArg(0),
|
||||
fl.Chain(
|
||||
fl.GetArg(1),
|
||||
fl.Conv2d(emb_dim, 1, 1, device=device),
|
||||
fl.Sigmoid(),
|
||||
Interpolate((size * 2, size * 2), "nearest"),
|
||||
PatchSplit(),
|
||||
),
|
||||
),
|
||||
fl.GetArg(1),
|
||||
),
|
||||
fl.Parallel(
|
||||
TiledCrossAttention(emb_dim, size, num_heads, pool_ratios, device=device),
|
||||
fl.GetArg(1),
|
||||
),
|
||||
fl.Concatenate(
|
||||
fl.GetArg(0),
|
||||
fl.Chain(
|
||||
fl.Sum(
|
||||
fl.GetArg(1),
|
||||
fl.Chain(
|
||||
fl.GetArg(0),
|
||||
PatchMerge(),
|
||||
Interpolate((size, size), "nearest"),
|
||||
),
|
||||
),
|
||||
fl.Unsqueeze(1),
|
||||
),
|
||||
dim=1,
|
||||
),
|
||||
)
|
337
src/refiners/foundationals/swin/mvanet/mvanet.py
Normal file
337
src/refiners/foundationals/swin/mvanet/mvanet.py
Normal file
|
@ -0,0 +1,337 @@
|
|||
# Multi-View Aggregation Network (arXiv:2404.07445)
|
||||
|
||||
from torch import device as Device
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.context import Contexts
|
||||
from refiners.foundationals.swin.swin_transformer import SwinTransformer
|
||||
|
||||
from .mclm import MCLM # Multi-View Complementary Localization
|
||||
from .mcrm import MCRM # Multi-View Complementary Refinement
|
||||
from .utils import BatchNorm2d, Interpolate, PatchMerge, PatchSplit, PReLU, Rescale, Unflatten
|
||||
|
||||
|
||||
class CBG(fl.Chain):
|
||||
"""(C)onvolution + (B)atchNorm + (G)eLU"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
out_dim = out_dim or in_dim
|
||||
super().__init__(
|
||||
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
|
||||
BatchNorm2d(out_dim, device=device),
|
||||
fl.GeLU(),
|
||||
)
|
||||
|
||||
|
||||
class CBR(fl.Chain):
|
||||
"""(C)onvolution + (B)atchNorm + Parametric (R)eLU"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
out_dim = out_dim or in_dim
|
||||
super().__init__(
|
||||
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
|
||||
BatchNorm2d(out_dim, device=device),
|
||||
PReLU(device=device),
|
||||
)
|
||||
|
||||
|
||||
class SplitMultiView(fl.Chain):
|
||||
"""
|
||||
Split a hd tensor into 5 ld views, (5 = 1 global + 4 tiles)
|
||||
See also the reverse Module [`RearrangeMultiView`][refiners.foundationals.swin.mvanet.RearrangeMultiView]
|
||||
|
||||
Inputs:
|
||||
single_view (b, c, H, W)
|
||||
|
||||
Outputs:
|
||||
multi_view (b, 5, c, H/2, W/2)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
fl.Concatenate(
|
||||
PatchSplit(), # global features
|
||||
fl.Chain( # local features
|
||||
Rescale(scale_factor=0.5, mode="bilinear"),
|
||||
fl.Unsqueeze(1),
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ShallowUpscaler(fl.Chain):
|
||||
"""4x Upscaler reusing the image as input to upscale the feature
|
||||
See [[arXiv:2108.10257] SwinIR: Image Restoration Using Swin Transformer](https://arxiv.org/abs/2108.10257)
|
||||
|
||||
Args:
|
||||
embedding_dim (int): the embedding dimension
|
||||
|
||||
Inputs:
|
||||
feature (b, E, image_size/4, image_size/4)
|
||||
|
||||
Output:
|
||||
upscaled tensor (b, E, image_size, image_size)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
fl.Identity(),
|
||||
fl.Chain(
|
||||
fl.UseContext("mvanet", "shallow"),
|
||||
Interpolate((256, 256)),
|
||||
),
|
||||
),
|
||||
fl.Sum(
|
||||
fl.Chain(
|
||||
Rescale(2),
|
||||
CBG(embedding_dim, device=device),
|
||||
),
|
||||
fl.Chain(
|
||||
fl.UseContext("mvanet", "shallow"),
|
||||
Interpolate((512, 512)),
|
||||
),
|
||||
),
|
||||
Rescale(2),
|
||||
CBG(embedding_dim, device=device),
|
||||
)
|
||||
|
||||
|
||||
class PyramidL5(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.GetArg(0), # output5
|
||||
fl.Flatten(0, 1),
|
||||
CBR(1024, embedding_dim, device=device),
|
||||
Unflatten(0, (-1, 5)),
|
||||
MCLM(embedding_dim, device=device),
|
||||
fl.Flatten(0, 1),
|
||||
Interpolate((32, 32)),
|
||||
)
|
||||
|
||||
|
||||
class PyramidL4(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
PyramidL5(embedding_dim=embedding_dim, device=device),
|
||||
fl.Chain(
|
||||
fl.GetArg(1),
|
||||
fl.Flatten(0, 1),
|
||||
CBR(512, embedding_dim, device=device), # output4
|
||||
Unflatten(0, (-1, 5)),
|
||||
),
|
||||
),
|
||||
MCRM(embedding_dim, 32, device=device), # dec_blk4
|
||||
fl.Flatten(0, 1),
|
||||
CBR(embedding_dim, device=device), # conv4
|
||||
Interpolate((64, 64)),
|
||||
)
|
||||
|
||||
|
||||
class PyramidL3(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
PyramidL4(embedding_dim=embedding_dim, device=device),
|
||||
fl.Chain(
|
||||
fl.GetArg(2),
|
||||
fl.Flatten(0, 1),
|
||||
CBR(256, embedding_dim, device=device), # output3
|
||||
Unflatten(0, (-1, 5)),
|
||||
),
|
||||
),
|
||||
MCRM(embedding_dim, 64, device=device), # dec_blk3
|
||||
fl.Flatten(0, 1),
|
||||
CBR(embedding_dim, device=device), # conv3
|
||||
Interpolate((128, 128)),
|
||||
)
|
||||
|
||||
|
||||
class PyramidL2(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
embedding_dim = 128
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
PyramidL3(embedding_dim=embedding_dim, device=device),
|
||||
fl.Chain(
|
||||
fl.GetArg(3),
|
||||
fl.Flatten(0, 1),
|
||||
CBR(128, embedding_dim, device=device), # output2
|
||||
Unflatten(0, (-1, 5)),
|
||||
),
|
||||
),
|
||||
MCRM(embedding_dim, 128, device=device), # dec_blk2
|
||||
fl.Flatten(0, 1),
|
||||
CBR(embedding_dim, device=device), # conv2
|
||||
Interpolate((128, 128)),
|
||||
)
|
||||
|
||||
|
||||
class Pyramid(fl.Chain):
|
||||
"""
|
||||
Recursive Pyramidal Network calling MCLM and MCRM blocks
|
||||
|
||||
It acts as a FPN (Feature Pyramid Network) Neck for MVANet
|
||||
see [[arXiv:1612.03144] Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
|
||||
|
||||
Inputs:
|
||||
features: a pyramid of N = 5 tensors
|
||||
shapes are (b, 5, E_{0}, S_{0}, S_{0}), ..., (b, 5, E_{1}, S_{i}, S_{i}), ..., (b, 5, E_{N-1}, S_{N-1}, S_{N-1})
|
||||
with S_{i} = S_{i-1} or S_{i} = 2*S_{i-1} for 0 < i < N
|
||||
|
||||
Outputs:
|
||||
output (b, 5, E, S_{N-1}, S_{N-1})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
PyramidL2(embedding_dim=embedding_dim, device=device),
|
||||
fl.Chain(
|
||||
fl.GetArg(4),
|
||||
fl.Flatten(0, 1),
|
||||
CBR(128, embedding_dim, device=device), # output1
|
||||
Unflatten(0, (-1, 5)),
|
||||
),
|
||||
),
|
||||
MCRM(embedding_dim, 128, device=device), # dec_blk1
|
||||
fl.Flatten(0, 1),
|
||||
CBR(embedding_dim, device=device), # conv1
|
||||
Unflatten(0, (-1, 5)),
|
||||
)
|
||||
|
||||
|
||||
class RearrangeMultiView(fl.Chain):
|
||||
"""
|
||||
Inputs:
|
||||
multi_view (b, 5, E, H, W)
|
||||
|
||||
Outputs:
|
||||
single_view (b, E, H*2, W*2)
|
||||
|
||||
Fusion a multi view tensor into a single view tensor, using convolutions
|
||||
See also the reverse Module [`SplitMultiView`][refiners.foundationals.swin.mvanet.SplitMultiView]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
fl.Chain( # local features
|
||||
fl.Slicing(dim=1, end=4),
|
||||
PatchMerge(),
|
||||
),
|
||||
fl.Chain( # global feature
|
||||
fl.Slicing(dim=1, start=4),
|
||||
fl.Squeeze(1),
|
||||
Interpolate((256, 256)),
|
||||
),
|
||||
),
|
||||
fl.Chain( # conv head
|
||||
CBR(embedding_dim, 384, device=device),
|
||||
CBR(384, device=device),
|
||||
fl.Conv2d(384, embedding_dim, kernel_size=3, padding=1, device=device),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ComputeShallow(fl.Passthrough):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device),
|
||||
fl.SetContext("mvanet", "shallow"),
|
||||
)
|
||||
|
||||
|
||||
class MVANet(fl.Chain):
|
||||
"""Multi-view Aggregation Network for Dichotomous Image Segmentation
|
||||
|
||||
See [[arXiv:2404.07445] Multi-view Aggregation Network for Dichotomous Image Segmentation](https://arxiv.org/abs/2404.07445) for more details.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): embedding dimension
|
||||
n_logits (int): the number of output logits (default to 1)
|
||||
1 logit is used for alpha matting/foreground-background segmentation/sod segmentation
|
||||
depths (list[int]): see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer]
|
||||
num_heads (list[int]): see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer]
|
||||
window_size (int): default to 12, see [`SwinTransformer`][refiners.foundationals.swin.swin_transformer.SwinTransformer]
|
||||
device (Device | None): the device to use
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int = 128,
|
||||
n_logits: int = 1,
|
||||
depths: list[int] | None = None,
|
||||
num_heads: list[int] | None = None,
|
||||
window_size: int = 12,
|
||||
device: Device | None = None,
|
||||
):
|
||||
if depths is None:
|
||||
depths = [2, 2, 18, 2]
|
||||
if num_heads is None:
|
||||
num_heads = [4, 8, 16, 32]
|
||||
|
||||
super().__init__(
|
||||
ComputeShallow(embedding_dim=embedding_dim, device=device),
|
||||
SplitMultiView(),
|
||||
fl.Flatten(0, 1),
|
||||
SwinTransformer(
|
||||
embedding_dim=embedding_dim,
|
||||
depths=depths,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
device=device,
|
||||
),
|
||||
fl.Distribute(*(Unflatten(0, (-1, 5)) for _ in range(5))),
|
||||
Pyramid(embedding_dim=embedding_dim, device=device),
|
||||
RearrangeMultiView(embedding_dim=embedding_dim, device=device),
|
||||
ShallowUpscaler(embedding_dim, device=device),
|
||||
fl.Conv2d(embedding_dim, n_logits, kernel_size=3, padding=1, device=device),
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"mvanet": {"shallow": None}}
|
173
src/refiners/foundationals/swin/mvanet/utils.py
Normal file
173
src/refiners/foundationals/swin/mvanet/utils.py
Normal file
|
@ -0,0 +1,173 @@
|
|||
import torch
|
||||
from torch import Size, Tensor
|
||||
from torch.nn.functional import (
|
||||
adaptive_avg_pool2d,
|
||||
interpolate, # type: ignore
|
||||
)
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
class Unflatten(fl.Module):
|
||||
def __init__(self, dim: int, sizes: tuple[int, ...]) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.sizes = Size(sizes)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return torch.unflatten(input=x, dim=self.dim, sizes=self.sizes)
|
||||
|
||||
|
||||
class Interpolate(fl.Module):
|
||||
def __init__(self, size: tuple[int, ...], mode: str = "bilinear"):
|
||||
super().__init__()
|
||||
self.size = Size(size)
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return interpolate(x, size=self.size, mode=self.mode) # type: ignore
|
||||
|
||||
|
||||
class Rescale(fl.Module):
|
||||
def __init__(self, scale_factor: float, mode: str = "nearest"):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return interpolate(x, scale_factor=self.scale_factor, mode=self.mode) # type: ignore
|
||||
|
||||
|
||||
class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule):
|
||||
def __init__(self, num_features: int, device: torch.device | None = None):
|
||||
super().__init__(num_features=num_features, device=device) # type: ignore
|
||||
|
||||
|
||||
class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation):
|
||||
def __init__(self, device: torch.device | None = None):
|
||||
super().__init__(device=device) # type: ignore
|
||||
|
||||
|
||||
class PatchSplit(fl.Chain):
|
||||
"""(B, N, H, W) -> B, 4, N, H/2, W/2"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
Unflatten(-2, (2, -1)),
|
||||
Unflatten(-1, (2, -1)),
|
||||
fl.Permute(0, 2, 4, 1, 3, 5),
|
||||
fl.Flatten(1, 2),
|
||||
)
|
||||
|
||||
|
||||
class PatchMerge(fl.Chain):
|
||||
"""B, 4, N, H, W -> (B, N, 2*H, 2*W)"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
Unflatten(1, (2, 2)),
|
||||
fl.Permute(0, 3, 1, 4, 2, 5),
|
||||
fl.Flatten(-2, -1),
|
||||
fl.Flatten(-3, -2),
|
||||
)
|
||||
|
||||
|
||||
class FeedForward(fl.Residual):
|
||||
def __init__(self, emb_dim: int, device: torch.device | None = None) -> None:
|
||||
super().__init__(
|
||||
fl.Linear(in_features=emb_dim, out_features=2 * emb_dim, device=device),
|
||||
fl.ReLU(),
|
||||
fl.Linear(in_features=2 * emb_dim, out_features=emb_dim, device=device),
|
||||
)
|
||||
|
||||
|
||||
class _GetArgs(fl.Parallel):
|
||||
def __init__(self, n: int):
|
||||
super().__init__(
|
||||
fl.Chain(
|
||||
fl.GetArg(0),
|
||||
fl.Slicing(dim=0, start=n, end=n + 1),
|
||||
fl.Squeeze(0),
|
||||
),
|
||||
fl.Chain(
|
||||
fl.GetArg(1),
|
||||
fl.Slicing(dim=0, start=n, end=n + 1),
|
||||
fl.Squeeze(0),
|
||||
),
|
||||
fl.Chain(
|
||||
fl.GetArg(1),
|
||||
fl.Slicing(dim=0, start=n, end=n + 1),
|
||||
fl.Squeeze(0),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule):
|
||||
def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None):
|
||||
super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore
|
||||
|
||||
@property
|
||||
def weight(self) -> Tensor: # type: ignore
|
||||
return self.in_proj_weight
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # type: ignore
|
||||
return super().forward(q, k, v)[0]
|
||||
|
||||
|
||||
class PatchwiseCrossAttention(fl.Chain):
|
||||
# Input is 2 tensors of sizes (4, HW, B, C) and (4, HW', B, C),
|
||||
# output is size (4, HW, B, C).
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
num_heads: int,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Concatenate(
|
||||
fl.Chain(
|
||||
_GetArgs(0),
|
||||
MultiheadAttention(d_model, num_heads, device=device),
|
||||
),
|
||||
fl.Chain(
|
||||
_GetArgs(1),
|
||||
MultiheadAttention(d_model, num_heads, device=device),
|
||||
),
|
||||
fl.Chain(
|
||||
_GetArgs(2),
|
||||
MultiheadAttention(d_model, num_heads, device=device),
|
||||
),
|
||||
fl.Chain(
|
||||
_GetArgs(3),
|
||||
MultiheadAttention(d_model, num_heads, device=device),
|
||||
),
|
||||
),
|
||||
Unflatten(0, (4, -1)),
|
||||
)
|
||||
|
||||
|
||||
class Pool(fl.Module):
|
||||
def __init__(self, ratio: int) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
b, _, h, w = x.shape
|
||||
assert h % self.ratio == 0 and w % self.ratio == 0
|
||||
r = adaptive_avg_pool2d(x, (h // self.ratio, w // self.ratio))
|
||||
return torch.unflatten(r, 0, (b, -1))
|
||||
|
||||
|
||||
class MultiPool(fl.Concatenate):
|
||||
def __init__(self, pool_ratios: list[int]) -> None:
|
||||
super().__init__(
|
||||
*(
|
||||
fl.Chain(
|
||||
Pool(pool_ratio),
|
||||
fl.Flatten(-2, -1),
|
||||
fl.Permute(0, 3, 1, 2),
|
||||
)
|
||||
for pool_ratio in pool_ratios
|
||||
),
|
||||
dim=1,
|
||||
)
|
391
src/refiners/foundationals/swin/swin_transformer.py
Normal file
391
src/refiners/foundationals/swin/swin_transformer.py
Normal file
|
@ -0,0 +1,391 @@
|
|||
# Swin Transformer (arXiv:2103.14030)
|
||||
#
|
||||
# Specific to MVANet, only supports square inputs.
|
||||
# Originally adapted from the version in MVANet and InSPyReNet (https://github.com/plemeri/InSPyReNet)
|
||||
# Original implementation by Microsoft at https://github.com/microsoft/Swin-Transformer
|
||||
|
||||
import functools
|
||||
from math import isqrt
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device as Device
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.context import Contexts
|
||||
|
||||
|
||||
def to_windows(x: Tensor, window_size: int) -> Tensor:
|
||||
B, H, W, C = x.shape
|
||||
assert W == H and H % window_size == 0
|
||||
x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||
return x.permute(0, 1, 3, 2, 4, 5).reshape(B, -1, window_size * window_size, C)
|
||||
|
||||
|
||||
class ToWindows(fl.Module):
|
||||
def __init__(self, window_size: int):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return to_windows(x, self.window_size)
|
||||
|
||||
|
||||
class FromWindows(fl.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, num_windows, window_size_2, C = x.shape
|
||||
window_size = isqrt(window_size_2)
|
||||
H = isqrt(num_windows * window_size_2)
|
||||
x = x.reshape(B, H // window_size, H // window_size, window_size, window_size, C)
|
||||
return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, H, C)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_attn_mask(H: int, window_size: int, device: Device | None = None) -> Tensor:
|
||||
assert H % window_size == 0
|
||||
shift_size = window_size // 2
|
||||
img_mask = torch.zeros((1, H, H, 1), device=device)
|
||||
h_slices = (
|
||||
slice(0, -window_size),
|
||||
slice(-window_size, -shift_size),
|
||||
slice(-shift_size, None),
|
||||
)
|
||||
w_slices = (
|
||||
slice(0, -window_size),
|
||||
slice(-window_size, -shift_size),
|
||||
slice(-shift_size, None),
|
||||
)
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = to_windows(img_mask, window_size).squeeze() # B, nW, window_size * window_size, [1]
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask.masked_fill_(attn_mask != 0, -100.0).masked_fill_(attn_mask == 0, 0.0)
|
||||
return attn_mask
|
||||
|
||||
|
||||
class Pad(fl.Module):
|
||||
def __init__(self, step: int):
|
||||
super().__init__()
|
||||
self.step = step
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, H, W, C = x.shape
|
||||
assert W == H
|
||||
if H % self.step == 0:
|
||||
return x
|
||||
p = self.step * ((H + self.step - 1) // self.step)
|
||||
padded = torch.zeros(B, p, p, C, device=x.device, dtype=x.dtype)
|
||||
padded[:, :H, :H, :] = x
|
||||
return padded
|
||||
|
||||
|
||||
class StatefulPad(fl.Chain):
|
||||
def __init__(self, context: str, key: str, step: int) -> None:
|
||||
super().__init__(
|
||||
fl.SetContext(context=context, key=key, callback=self._push),
|
||||
Pad(step=step),
|
||||
)
|
||||
|
||||
def _push(self, sizes: list[int], x: Tensor) -> None:
|
||||
sizes.append(x.size(1))
|
||||
|
||||
|
||||
class StatefulUnpad(fl.Chain):
|
||||
def __init__(self, context: str, key: str) -> None:
|
||||
super().__init__(
|
||||
fl.Parallel(
|
||||
fl.Identity(),
|
||||
fl.UseContext(context=context, key=key).compose(lambda x: x.pop()),
|
||||
),
|
||||
fl.Lambda(self._unpad),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _unpad(x: Tensor, size: int) -> Tensor:
|
||||
return x[:, :size, :size, :]
|
||||
|
||||
|
||||
class SquareUnflatten(fl.Module):
|
||||
# ..., L^2, ... -> ..., L, L, ...
|
||||
|
||||
def __init__(self, dim: int = 0) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
d = isqrt(x.shape[self.dim])
|
||||
return torch.unflatten(x, self.dim, (d, d))
|
||||
|
||||
|
||||
class WindowUnflatten(fl.Module):
|
||||
# ..., H, ... -> ..., H // ws, ws, ...
|
||||
|
||||
def __init__(self, window_size: int, dim: int = 0) -> None:
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
assert x.shape[self.dim] % self.window_size == 0
|
||||
H = x.shape[self.dim]
|
||||
return torch.unflatten(x, self.dim, (H // self.window_size, self.window_size))
|
||||
|
||||
|
||||
class Roll(fl.Module):
|
||||
def __init__(self, *shifts: tuple[int, int]):
|
||||
super().__init__()
|
||||
self.shifts = shifts
|
||||
self._dims = tuple(s[0] for s in shifts)
|
||||
self._shifts = tuple(s[1] for s in shifts)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return torch.roll(x, self._shifts, self._dims)
|
||||
|
||||
|
||||
class RelativePositionBias(fl.Module):
|
||||
relative_position_index: Tensor
|
||||
|
||||
def __init__(self, window_size: int, num_heads: int, device: Device | None = None):
|
||||
super().__init__()
|
||||
self.relative_position_bias_table = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
(2 * window_size - 1) * (2 * window_size - 1),
|
||||
num_heads,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
relative_position_index = torch.empty(
|
||||
window_size**2,
|
||||
window_size**2,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
def forward(self) -> Tensor:
|
||||
# Yes, this is a (trainable) constant.
|
||||
return self.relative_position_bias_table[self.relative_position_index].permute(2, 0, 1).unsqueeze(0)
|
||||
|
||||
|
||||
class WindowSDPA(fl.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
window_size: int,
|
||||
num_heads: int,
|
||||
shift: bool = False,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_heads = num_heads
|
||||
self.shift = shift
|
||||
self.rpb = RelativePositionBias(window_size, num_heads, device=device)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
B, num_windows, N, _C = x.shape
|
||||
assert _C % (3 * self.num_heads) == 0
|
||||
C = _C // 3
|
||||
x = torch.reshape(x, (B * num_windows, N, 3, self.num_heads, C // self.num_heads))
|
||||
q, k, v = x.permute(2, 0, 3, 1, 4)
|
||||
|
||||
attn_mask = self.rpb()
|
||||
if self.shift:
|
||||
mask = get_attn_mask(isqrt(num_windows * (self.window_size**2)), self.window_size, x.device)
|
||||
mask = mask.reshape(1, num_windows, 1, N, N)
|
||||
mask = mask.expand(B, -1, self.num_heads, -1, -1)
|
||||
attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
x = x.transpose(1, 2).reshape(B, num_windows, N, C)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(fl.Chain):
|
||||
"""
|
||||
Window-based Multi-head Self-Attenion (W-MSA), optionally shifted (SW-MSA).
|
||||
|
||||
It has a trainable relative position bias (RelativePositionBias).
|
||||
|
||||
The input projection is stored as a single Linear for q, k and v.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
window_size: int,
|
||||
num_heads: int,
|
||||
shift: bool = False,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Linear(dim, dim * 3, bias=True, device=device),
|
||||
WindowSDPA(dim, window_size, num_heads, shift, device=device),
|
||||
fl.Linear(dim, dim, device=device),
|
||||
)
|
||||
|
||||
|
||||
class SwinTransformerBlock(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
window_size: int = 7,
|
||||
shift_size: int = 0,
|
||||
mlp_ratio: float = 4.0,
|
||||
device: Device | None = None,
|
||||
):
|
||||
assert 0 <= shift_size < window_size, "shift_size must in [0, window_size["
|
||||
|
||||
super().__init__(
|
||||
fl.Residual(
|
||||
fl.LayerNorm(dim, device=device),
|
||||
SquareUnflatten(1),
|
||||
StatefulPad(context="padding", key="sizes", step=window_size),
|
||||
Roll((1, -shift_size), (2, -shift_size)),
|
||||
ToWindows(window_size),
|
||||
WindowAttention(
|
||||
dim,
|
||||
window_size=window_size,
|
||||
num_heads=num_heads,
|
||||
shift=shift_size > 0,
|
||||
device=device,
|
||||
),
|
||||
FromWindows(),
|
||||
Roll((1, shift_size), (2, shift_size)),
|
||||
StatefulUnpad(context="padding", key="sizes"),
|
||||
fl.Flatten(1, 2),
|
||||
),
|
||||
fl.Residual(
|
||||
fl.LayerNorm(dim, device=device),
|
||||
fl.Linear(dim, int(dim * mlp_ratio), device=device),
|
||||
fl.GeLU(),
|
||||
fl.Linear(int(dim * mlp_ratio), dim, device=device),
|
||||
),
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"padding": {"sizes": []}}
|
||||
|
||||
|
||||
class PatchMerging(fl.Chain):
|
||||
def __init__(self, dim: int, device: Device | None = None):
|
||||
super().__init__(
|
||||
SquareUnflatten(1),
|
||||
Pad(2),
|
||||
WindowUnflatten(2, 2),
|
||||
WindowUnflatten(2, 1),
|
||||
fl.Permute(0, 1, 3, 4, 2, 5),
|
||||
fl.Flatten(3),
|
||||
fl.Flatten(1, 2),
|
||||
fl.LayerNorm(4 * dim, device=device),
|
||||
fl.Linear(4 * dim, 2 * dim, bias=False, device=device),
|
||||
)
|
||||
|
||||
|
||||
class BasicLayer(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
window_size: int = 7,
|
||||
mlp_ratio: float = 4.0,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
SwinTransformerBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device,
|
||||
)
|
||||
for i in range(depth)
|
||||
)
|
||||
|
||||
|
||||
class PatchEmbedding(fl.Chain):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: tuple[int, int] = (4, 4),
|
||||
in_chans: int = 3,
|
||||
embedding_dim: int = 96,
|
||||
device: Device | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device),
|
||||
fl.Flatten(2),
|
||||
fl.Transpose(1, 2),
|
||||
fl.LayerNorm(embedding_dim, device=device),
|
||||
)
|
||||
|
||||
|
||||
class SwinTransformer(fl.Chain):
|
||||
"""Swin Transformer (arXiv:2103.14030)
|
||||
|
||||
Currently specific to MVANet, only supports square inputs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: tuple[int, int] = (4, 4),
|
||||
in_chans: int = 3,
|
||||
embedding_dim: int = 96,
|
||||
depths: list[int] | None = None,
|
||||
num_heads: list[int] | None = None,
|
||||
window_size: int = 7, # image size is 32 * this
|
||||
mlp_ratio: float = 4.0,
|
||||
device: Device | None = None,
|
||||
):
|
||||
if depths is None:
|
||||
depths = [2, 2, 6, 2]
|
||||
|
||||
if num_heads is None:
|
||||
num_heads = [3, 6, 12, 24]
|
||||
|
||||
self.num_layers = len(depths)
|
||||
assert len(num_heads) == self.num_layers
|
||||
|
||||
super().__init__(
|
||||
PatchEmbedding(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embedding_dim=embedding_dim,
|
||||
device=device,
|
||||
),
|
||||
fl.Passthrough(
|
||||
fl.Transpose(1, 2),
|
||||
SquareUnflatten(2),
|
||||
fl.SetContext("swin", "outputs", callback=lambda t, x: t.append(x)),
|
||||
),
|
||||
*(
|
||||
fl.Chain(
|
||||
BasicLayer(
|
||||
dim=int(embedding_dim * 2**i),
|
||||
depth=depths[i],
|
||||
num_heads=num_heads[i],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device,
|
||||
),
|
||||
fl.Passthrough(
|
||||
fl.LayerNorm(int(embedding_dim * 2**i), device=device),
|
||||
fl.Transpose(1, 2),
|
||||
SquareUnflatten(2),
|
||||
fl.SetContext("swin", "outputs", callback=lambda t, x: t.insert(0, x)),
|
||||
),
|
||||
PatchMerging(dim=int(embedding_dim * 2**i), device=device)
|
||||
if i < self.num_layers - 1
|
||||
else fl.UseContext("swin", "outputs").compose(lambda t: tuple(t)),
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
),
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"swin": {"outputs": []}}
|
59
tests/e2e/test_mvanet.py
Normal file
59
tests/e2e/test_mvanet.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, no_grad, normalize, tensor_to_image
|
||||
from refiners.foundationals.swin.mvanet import MVANet
|
||||
|
||||
|
||||
def _img_open(path: Path) -> Image.Image:
|
||||
return Image.open(path) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_path(test_e2e_path: Path) -> Path:
|
||||
return test_e2e_path / "test_mvanet_ref"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_cactus(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "cactus.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_cactus_mask(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_cactus_mask.png")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mvanet_weights(test_weights_path: Path) -> Path:
|
||||
weights = test_weights_path / "mvanet" / "mvanet.safetensors"
|
||||
if not weights.is_file():
|
||||
warn(f"could not find weights at {test_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
return weights
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mvanet_model(mvanet_weights: Path, test_device: torch.device) -> MVANet:
|
||||
model = MVANet(device=test_device).eval() # .eval() is important!
|
||||
model.load_from_safetensors(mvanet_weights)
|
||||
return model
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_mvanet(
|
||||
mvanet_model: MVANet,
|
||||
ref_cactus: Image.Image,
|
||||
expected_cactus_mask: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
in_t = image_to_tensor(ref_cactus.resize((1024, 1024), Image.Resampling.BILINEAR)).squeeze()
|
||||
in_t = normalize(in_t, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
|
||||
prediction: torch.Tensor = mvanet_model(in_t.to(test_device)).sigmoid()
|
||||
cactus_mask = tensor_to_image(prediction).resize(ref_cactus.size, Image.Resampling.BILINEAR)
|
||||
ensure_similar_images(cactus_mask.convert("RGB"), expected_cactus_mask.convert("RGB"))
|
3
tests/e2e/test_mvanet_ref/README.md
Normal file
3
tests/e2e/test_mvanet_ref/README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
`cactus.png` is cropped from this image: https://www.freepik.com/free-photo/laptop-notebook-pen-coffee-cup-plants-wooden-desk_269339828.htm
|
||||
|
||||
`expected_cactus_mask.png` has been generated using the [official MVANet codebase](https://github.com/qianyu-dlut/MVANet) and weights.
|
BIN
tests/e2e/test_mvanet_ref/cactus.png
Normal file
BIN
tests/e2e/test_mvanet_ref/cactus.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 96 KiB |
BIN
tests/e2e/test_mvanet_ref/expected_cactus_mask.png
Normal file
BIN
tests/e2e/test_mvanet_ref/expected_cactus_mask.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.8 KiB |
1
typings/gdown/__init__.pyi
Normal file
1
typings/gdown/__init__.pyi
Normal file
|
@ -0,0 +1 @@
|
|||
def download(id: str, output: str, quiet: bool = False) -> str: ...
|
Loading…
Reference in a new issue