From 0046d2288fa8cd09ff36334d389f78f9abc179f3 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 27 Aug 2024 17:53:10 +0200 Subject: [PATCH] return typing for __init__ --- .../foundationals/swin/mvanet/mclm.py | 14 +++++------ .../foundationals/swin/mvanet/mcrm.py | 4 ++-- .../foundationals/swin/mvanet/mvanet.py | 24 +++++++++---------- .../foundationals/swin/mvanet/utils.py | 18 +++++++------- .../foundationals/swin/swin_transformer.py | 22 ++++++++--------- 5 files changed, 41 insertions(+), 41 deletions(-) diff --git a/src/refiners/foundationals/swin/mvanet/mclm.py b/src/refiners/foundationals/swin/mvanet/mclm.py index b860d06..b0a21f7 100644 --- a/src/refiners/foundationals/swin/mvanet/mclm.py +++ b/src/refiners/foundationals/swin/mvanet/mclm.py @@ -14,7 +14,7 @@ from .utils import FeedForward, MultiheadAttention, MultiPool, PatchMerge, Patch class PerPixel(fl.Chain): """(B, C, H, W) -> H*W, B, C""" - def __init__(self): + def __init__(self) -> None: super().__init__( fl.Permute(2, 3, 0, 1), fl.Flatten(0, 1), @@ -26,7 +26,7 @@ class PositionEmbeddingSine(fl.Module): Non-trainable position embedding, originally from https://github.com/facebookresearch/detr """ - def __init__(self, num_pos_feats: int, device: Device | None = None): + def __init__(self, num_pos_feats: int, device: Device | None = None) -> None: super().__init__() self.device = device temperature = 10000 @@ -51,7 +51,7 @@ class PositionEmbeddingSine(fl.Module): class MultiPoolPos(fl.Module): - def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine): + def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine) -> None: super().__init__() self.pool_ratios = pool_ratios self.positional_embedding = positional_embedding @@ -62,7 +62,7 @@ class MultiPoolPos(fl.Module): class Repeat(fl.Module): - def __init__(self, dim: int = 0): + def __init__(self, dim: int = 0) -> None: self.dim = dim super().__init__() @@ -71,7 +71,7 @@ class Repeat(fl.Module): class _MHA_Arg(fl.Sum): - def __init__(self, offset: int): + def __init__(self, offset: int) -> None: self.offset = offset super().__init__( fl.GetArg(offset), # value @@ -95,7 +95,7 @@ class GlobalAttention(fl.Chain): emb_dim: int, num_heads: int = 1, device: Device | None = None, - ): + ) -> None: super().__init__( fl.Sum( fl.GetArg(0), # global @@ -125,7 +125,7 @@ class MCLM(fl.Chain): num_heads: int = 1, pool_ratios: list[int] | None = None, device: Device | None = None, - ): + ) -> None: if pool_ratios is None: pool_ratios = [2, 8, 16] diff --git a/src/refiners/foundationals/swin/mvanet/mcrm.py b/src/refiners/foundationals/swin/mvanet/mcrm.py index 01311da..fc31a12 100644 --- a/src/refiners/foundationals/swin/mvanet/mcrm.py +++ b/src/refiners/foundationals/swin/mvanet/mcrm.py @@ -24,7 +24,7 @@ class TiledCrossAttention(fl.Chain): num_heads: int = 1, pool_ratios: list[int] | None = None, device: Device | None = None, - ): + ) -> None: # Input must be a 4-tuple: (local, global) if pool_ratios is None: @@ -70,7 +70,7 @@ class MCRM(fl.Chain): num_heads: int = 1, pool_ratios: list[int] | None = None, device: Device | None = None, - ): + ) -> None: if pool_ratios is None: pool_ratios = [1, 2, 4] diff --git a/src/refiners/foundationals/swin/mvanet/mvanet.py b/src/refiners/foundationals/swin/mvanet/mvanet.py index 4410eb6..0e314e6 100644 --- a/src/refiners/foundationals/swin/mvanet/mvanet.py +++ b/src/refiners/foundationals/swin/mvanet/mvanet.py @@ -19,7 +19,7 @@ class CBG(fl.Chain): in_dim: int, out_dim: int | None = 
None, device: Device | None = None, - ): + ) -> None: out_dim = out_dim or in_dim super().__init__( fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device), @@ -36,7 +36,7 @@ class CBR(fl.Chain): in_dim: int, out_dim: int | None = None, device: Device | None = None, - ): + ) -> None: out_dim = out_dim or in_dim super().__init__( fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device), @@ -57,7 +57,7 @@ class SplitMultiView(fl.Chain): multi_view (b, 5, c, H/2, W/2) """ - def __init__(self): + def __init__(self) -> None: super().__init__( fl.Concatenate( PatchSplit(), # global features @@ -88,7 +88,7 @@ class ShallowUpscaler(fl.Chain): self, embedding_dim: int = 128, device: Device | None = None, - ): + ) -> None: super().__init__( fl.Sum( fl.Identity(), @@ -117,7 +117,7 @@ class PyramidL5(fl.Chain): self, embedding_dim: int = 128, device: Device | None = None, - ): + ) -> None: super().__init__( fl.GetArg(0), # output5 fl.Flatten(0, 1), @@ -134,7 +134,7 @@ class PyramidL4(fl.Chain): self, embedding_dim: int = 128, device: Device | None = None, - ): + ) -> None: super().__init__( fl.Sum( PyramidL5(embedding_dim=embedding_dim, device=device), @@ -157,7 +157,7 @@ class PyramidL3(fl.Chain): self, embedding_dim: int = 128, device: Device | None = None, - ): + ) -> None: super().__init__( fl.Sum( PyramidL4(embedding_dim=embedding_dim, device=device), @@ -180,7 +180,7 @@ class PyramidL2(fl.Chain): self, embedding_dim: int = 128, device: Device | None = None, - ): + ) -> None: embedding_dim = 128 super().__init__( fl.Sum( @@ -219,7 +219,7 @@ class Pyramid(fl.Chain): self, embedding_dim: int = 128, device: Device | None = None, - ): + ) -> None: super().__init__( fl.Sum( PyramidL2(embedding_dim=embedding_dim, device=device), @@ -253,7 +253,7 @@ class RearrangeMultiView(fl.Chain): self, embedding_dim: int = 128, device: Device | None = None, - ): + ) -> None: super().__init__( fl.Sum( fl.Chain( # local features @@ -279,7 +279,7 @@ class ComputeShallow(fl.Passthrough): self, embedding_dim: int = 128, device: Device | None = None, - ): + ) -> None: super().__init__( fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device), fl.SetContext("mvanet", "shallow"), @@ -309,7 +309,7 @@ class MVANet(fl.Chain): num_heads: list[int] | None = None, window_size: int = 12, device: Device | None = None, - ): + ) -> None: if depths is None: depths = [2, 2, 18, 2] if num_heads is None: diff --git a/src/refiners/foundationals/swin/mvanet/utils.py b/src/refiners/foundationals/swin/mvanet/utils.py index 363b55a..d7456dd 100644 --- a/src/refiners/foundationals/swin/mvanet/utils.py +++ b/src/refiners/foundationals/swin/mvanet/utils.py @@ -19,7 +19,7 @@ class Unflatten(fl.Module): class Interpolate(fl.Module): - def __init__(self, size: tuple[int, ...], mode: str = "bilinear"): + def __init__(self, size: tuple[int, ...], mode: str = "bilinear") -> None: super().__init__() self.size = Size(size) self.mode = mode @@ -29,7 +29,7 @@ class Interpolate(fl.Module): class Rescale(fl.Module): - def __init__(self, scale_factor: float, mode: str = "nearest"): + def __init__(self, scale_factor: float, mode: str = "nearest") -> None: super().__init__() self.scale_factor = scale_factor self.mode = mode @@ -39,19 +39,19 @@ class Rescale(fl.Module): class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule): - def __init__(self, num_features: int, device: torch.device | None = None): + def __init__(self, num_features: int, device: torch.device | None = None) -> None: 
super().__init__(num_features=num_features, device=device) # type: ignore class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation): - def __init__(self, device: torch.device | None = None): + def __init__(self, device: torch.device | None = None) -> None: super().__init__(device=device) # type: ignore class PatchSplit(fl.Chain): """(B, N, H, W) -> B, 4, N, H/2, W/2""" - def __init__(self): + def __init__(self) -> None: super().__init__( Unflatten(-2, (2, -1)), Unflatten(-1, (2, -1)), @@ -63,7 +63,7 @@ class PatchSplit(fl.Chain): class PatchMerge(fl.Chain): """B, 4, N, H, W -> (B, N, 2*H, 2*W)""" - def __init__(self): + def __init__(self) -> None: super().__init__( Unflatten(1, (2, 2)), fl.Permute(0, 3, 1, 4, 2, 5), @@ -82,7 +82,7 @@ class FeedForward(fl.Residual): class _GetArgs(fl.Parallel): - def __init__(self, n: int): + def __init__(self, n: int) -> None: super().__init__( fl.Chain( fl.GetArg(0), @@ -103,7 +103,7 @@ class _GetArgs(fl.Parallel): class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule): - def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None): + def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None) -> None: super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore @property @@ -122,7 +122,7 @@ class PatchwiseCrossAttention(fl.Chain): d_model: int, num_heads: int, device: torch.device | None = None, - ): + ) -> None: super().__init__( fl.Concatenate( fl.Chain( diff --git a/src/refiners/foundationals/swin/swin_transformer.py b/src/refiners/foundationals/swin/swin_transformer.py index f1291fe..dd1fca4 100644 --- a/src/refiners/foundationals/swin/swin_transformer.py +++ b/src/refiners/foundationals/swin/swin_transformer.py @@ -22,7 +22,7 @@ def to_windows(x: Tensor, window_size: int) -> Tensor: class ToWindows(fl.Module): - def __init__(self, window_size: int): + def __init__(self, window_size: int) -> None: super().__init__() self.window_size = window_size @@ -67,7 +67,7 @@ def get_attn_mask(H: int, window_size: int, device: Device | None = None) -> Ten class Pad(fl.Module): - def __init__(self, step: int): + def __init__(self, step: int) -> None: super().__init__() self.step = step @@ -135,7 +135,7 @@ class WindowUnflatten(fl.Module): class Roll(fl.Module): - def __init__(self, *shifts: tuple[int, int]): + def __init__(self, *shifts: tuple[int, int]) -> None: super().__init__() self.shifts = shifts self._dims = tuple(s[0] for s in shifts) @@ -148,7 +148,7 @@ class Roll(fl.Module): class RelativePositionBias(fl.Module): relative_position_index: Tensor - def __init__(self, window_size: int, num_heads: int, device: Device | None = None): + def __init__(self, window_size: int, num_heads: int, device: Device | None = None) -> None: super().__init__() self.relative_position_bias_table = torch.nn.Parameter( torch.empty( @@ -178,7 +178,7 @@ class WindowSDPA(fl.Module): num_heads: int, shift: bool = False, device: Device | None = None, - ): + ) -> None: super().__init__() self.window_size = window_size self.num_heads = num_heads @@ -220,7 +220,7 @@ class WindowAttention(fl.Chain): num_heads: int, shift: bool = False, device: Device | None = None, - ): + ) -> None: super().__init__( fl.Linear(dim, dim * 3, bias=True, device=device), WindowSDPA(dim, window_size, num_heads, shift, device=device), @@ -237,7 +237,7 @@ class SwinTransformerBlock(fl.Chain): shift_size: int = 0, mlp_ratio: float = 4.0, device: Device | None = None, - ): + ) -> None: assert 0 <= 
shift_size < window_size, "shift_size must in [0, window_size[" super().__init__( @@ -272,7 +272,7 @@ class SwinTransformerBlock(fl.Chain): class PatchMerging(fl.Chain): - def __init__(self, dim: int, device: Device | None = None): + def __init__(self, dim: int, device: Device | None = None) -> None: super().__init__( SquareUnflatten(1), Pad(2), @@ -295,7 +295,7 @@ class BasicLayer(fl.Chain): window_size: int = 7, mlp_ratio: float = 4.0, device: Device | None = None, - ): + ) -> None: super().__init__( SwinTransformerBlock( dim=dim, @@ -316,7 +316,7 @@ class PatchEmbedding(fl.Chain): in_chans: int = 3, embedding_dim: int = 96, device: Device | None = None, - ): + ) -> None: super().__init__( fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device), fl.Flatten(2), @@ -341,7 +341,7 @@ class SwinTransformer(fl.Chain): window_size: int = 7, # image size is 32 * this mlp_ratio: float = 4.0, device: Device | None = None, - ): + ) -> None: if depths is None: depths = [2, 2, 6, 2]
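
Note on the change itself: adding -> None to __init__ has no runtime effect. It makes the constructor signatures fully annotated, so strict type-checking configurations stop reporting them as incomplete. Below is a minimal sketch of the behavior this patch addresses, assuming mypy with --disallow-incomplete-defs (part of --strict); the checker and flags actually used by the project are not shown in this patch.

    # sketch.py -- hypothetical file, not part of the patch
    class Before:
        # Parameter annotated, return type missing: mypy in strict mode reports
        # something like:
        #   error: Function is missing a return type annotation  [no-untyped-def]
        #   note: Use "-> None" if function does not return a value
        def __init__(self, num_features: int):
            self.num_features = num_features


    class After:
        # Fully annotated signature: no diagnostic, identical behavior at runtime.
        def __init__(self, num_features: int) -> None:
            self.num_features = num_features

This is also why the patch touches only the __init__ signatures and leaves the module bodies unchanged.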
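
Separately, for readers of the utils.py hunks: the PatchSplit ((B, N, H, W) -> (B, 4, N, H/2, W/2)) and PatchMerge ((B, 4, N, H, W) -> (B, N, 2*H, 2*W)) docstrings describe inverse 2x2 tiling operations, but their chains are only partially visible in the surrounding context lines. The sketch below is a plain-PyTorch reconstruction of the shape round trip those docstrings describe, not the refiners implementation itself.

    import torch

    def patch_split(x: torch.Tensor) -> torch.Tensor:
        # (B, N, H, W) -> (B, 4, N, H/2, W/2): cut each map into a 2x2 grid of tiles
        B, N, H, W = x.shape
        x = x.reshape(B, N, 2, H // 2, 2, W // 2)
        x = x.permute(0, 2, 4, 1, 3, 5)  # (B, 2, 2, N, H/2, W/2)
        return x.reshape(B, 4, N, H // 2, W // 2)

    def patch_merge(x: torch.Tensor) -> torch.Tensor:
        # (B, 4, N, H, W) -> (B, N, 2*H, 2*W): stitch the 2x2 grid of tiles back together
        B, _, N, H, W = x.shape
        x = x.reshape(B, 2, 2, N, H, W)
        x = x.permute(0, 3, 1, 4, 2, 5)  # (B, N, 2, H, 2, W)
        return x.reshape(B, N, 2 * H, 2 * W)

    x = torch.randn(1, 3, 8, 8)
    assert torch.equal(patch_merge(patch_split(x)), x)  # the two transforms are inverses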