
Swin Transformers

SwinTransformer

SwinTransformer(
    patch_size: tuple[int, int] = (4, 4),
    in_chans: int = 3,
    embedding_dim: int = 96,
    depths: list[int] | None = None,
    num_heads: list[int] | None = None,
    window_size: int = 7,
    mlp_ratio: float = 4.0,
    device: Device | None = None,
)

Bases: Chain

Swin Transformer (arXiv:2103.14030)

Currently tailored to MVANet; only square inputs are supported.

Source code in src/refiners/foundationals/swin/swin_transformer.py
def __init__(
    self,
    patch_size: tuple[int, int] = (4, 4),
    in_chans: int = 3,
    embedding_dim: int = 96,
    depths: list[int] | None = None,
    num_heads: list[int] | None = None,
    window_size: int = 7,  # image size is 32 * this
    mlp_ratio: float = 4.0,
    device: Device | None = None,
) -> None:
    if depths is None:
        depths = [2, 2, 6, 2]

    if num_heads is None:
        num_heads = [3, 6, 12, 24]

    self.num_layers = len(depths)
    assert len(num_heads) == self.num_layers

    super().__init__(
        PatchEmbedding(
            patch_size=patch_size,
            in_chans=in_chans,
            embedding_dim=embedding_dim,
            device=device,
        ),
        fl.Passthrough(
            fl.Transpose(1, 2),
            SquareUnflatten(2),
            fl.SetContext("swin", "outputs", callback=lambda t, x: t.append(x)),
        ),
        *(
            fl.Chain(
                BasicLayer(
                    dim=int(embedding_dim * 2**i),
                    depth=depths[i],
                    num_heads=num_heads[i],
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    device=device,
                ),
                fl.Passthrough(
                    fl.LayerNorm(int(embedding_dim * 2**i), device=device),
                    fl.Transpose(1, 2),
                    SquareUnflatten(2),
                    fl.SetContext("swin", "outputs", callback=lambda t, x: t.insert(0, x)),
                ),
                PatchMerging(dim=int(embedding_dim * 2**i), device=device)
                if i < self.num_layers - 1
                else fl.UseContext("swin", "outputs").compose(lambda t: tuple(t)),
            )
            for i in range(self.num_layers)
        ),
    )
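
A minimal usage sketch (illustrative; assumes refiners and torch are installed, and uses the default configuration, where the expected input size is 32 * window_size = 224):

import torch

from refiners.foundationals.swin.swin_transformer import SwinTransformer

model = SwinTransformer()  # defaults: patch_size=(4, 4), embedding_dim=96, window_size=7
x = torch.randn(1, 3, 224, 224)  # square input: 4 (patch) * 2**3 (mergings) * 7 (window) = 224
features = model(x)  # tuple of feature maps, deepest stage first, patch embedding last
for f in features:
    print(f.shape)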

WindowAttention

WindowAttention(
    dim: int,
    window_size: int,
    num_heads: int,
    shift: bool = False,
    device: Device | None = None,
)

Bases: Chain

Window-based Multi-head Self-Attention (W-MSA), optionally shifted (SW-MSA).

It has a trainable relative position bias (RelativePositionBias).
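
As a rough sketch of how such a bias works (this is the standard Swin formulation, not refiners' exact code): each head adds a learned bias to the attention scores, looked up by the relative offset between every pair of tokens in a window.

import torch

window_size, num_heads = 7, 3
n = window_size * window_size
bias_table = torch.zeros((2 * window_size - 1) ** 2, num_heads)  # trainable in the real model

# relative offset between every pair of tokens in a window, shifted to be non-negative
coords = torch.stack(
    torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij")
).flatten(1)                                                 # (2, n)
rel = (coords[:, :, None] - coords[:, None, :]).permute(1, 2, 0) + window_size - 1
index = rel[..., 0] * (2 * window_size - 1) + rel[..., 1]    # (n, n)

bias = bias_table[index.reshape(-1)].reshape(n, n, num_heads).permute(2, 0, 1)
# bias has shape (num_heads, n, n) and is added to the attention scores before softmax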

The input projection is stored as a single Linear for q, k and v.

Source code in src/refiners/foundationals/swin/swin_transformer.py
def __init__(
    self,
    dim: int,
    window_size: int,
    num_heads: int,
    shift: bool = False,
    device: Device | None = None,
) -> None:
    super().__init__(
        fl.Linear(dim, dim * 3, bias=True, device=device),
        WindowSDPA(window_size, num_heads, shift, device=device),
        fl.Linear(dim, dim, device=device),
    )
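
The single Linear above produces q, k and v in one matrix multiply; WindowSDPA splits them afterwards. A minimal sketch of that splitting (illustrative, not refiners' exact code; window partitioning, shifting and the relative position bias are omitted):

import torch
import torch.nn.functional as F

dim, num_heads = 96, 3
tokens = torch.randn(8, 49, dim)  # (num_windows, window_size**2, dim)

qkv_proj = torch.nn.Linear(dim, dim * 3, bias=True)  # single projection for q, k and v
q, k, v = qkv_proj(tokens).chunk(3, dim=-1)          # each (num_windows, 49, dim)

def split_heads(t: torch.Tensor) -> torch.Tensor:
    b, n, d = t.shape
    return t.view(b, n, num_heads, d // num_heads).transpose(1, 2)

out = F.scaled_dot_product_attention(*map(split_heads, (q, k, v)))
out = out.transpose(1, 2).reshape(8, 49, dim)  # merge heads back before the output Linear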

MVANet

MVANet(
    embedding_dim: int = 128,
    n_logits: int = 1,
    depths: list[int] | None = None,
    num_heads: list[int] | None = None,
    window_size: int = 12,
    device: Device | None = None,
)

Bases: Chain

Multi-view Aggregation Network for Dichotomous Image Segmentation

See arXiv:2404.07445, Multi-view Aggregation Network for Dichotomous Image Segmentation, for more details.

Parameters:

Name          | Type              | Description | Default
------------- | ----------------- | ----------- | -------
embedding_dim | int               | embedding dimension | 128
n_logits      | int               | number of output logits; a single logit is used for alpha matting, foreground-background segmentation and salient object detection (SOD) | 1
depths        | list[int] or None | see SwinTransformer | None
num_heads     | list[int] or None | see SwinTransformer | None
window_size   | int               | see SwinTransformer | 12
device        | Device or None    | the device to use | None
Source code in src/refiners/foundationals/swin/mvanet/mvanet.py
def __init__(
    self,
    embedding_dim: int = 128,
    n_logits: int = 1,
    depths: list[int] | None = None,
    num_heads: list[int] | None = None,
    window_size: int = 12,
    device: Device | None = None,
) -> None:
    if depths is None:
        depths = [2, 2, 18, 2]
    if num_heads is None:
        num_heads = [4, 8, 16, 32]

    super().__init__(
        ComputeShallow(embedding_dim=embedding_dim, device=device),
        SplitMultiView(),
        fl.Flatten(0, 1),
        SwinTransformer(
            embedding_dim=embedding_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            device=device,
        ),
        fl.Distribute(*(Unflatten(0, (-1, 5)) for _ in range(5))),
        Pyramid(embedding_dim=embedding_dim, device=device),
        RearrangeMultiView(embedding_dim=embedding_dim, device=device),
        ShallowUpscaler(embedding_dim, device=device),
        fl.Conv2d(embedding_dim, n_logits, kernel_size=3, padding=1, device=device),
    )
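
A minimal inference sketch (illustrative; assumes refiners and torch are installed; the 1024x1024 input size and the import path, inferred from the source location above, are assumptions, and pretrained weights must be loaded separately):

import torch

from refiners.foundationals.swin.mvanet.mvanet import MVANet

model = MVANet().eval()
image = torch.randn(1, 3, 1024, 1024)  # square input
with torch.no_grad():
    logits = model(image)  # (1, n_logits, H, W)
mask = logits.sigmoid()    # per-pixel probabilities for segmentation / matting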