return typing for __init__

This commit is contained in:
Pierre Chapuis 2024-08-27 17:53:10 +02:00
parent 8aa1d9d91d
commit 0046d2288f
5 changed files with 41 additions and 41 deletions

View file

@ -14,7 +14,7 @@ from .utils import FeedForward, MultiheadAttention, MultiPool, PatchMerge, Patch
class PerPixel(fl.Chain): class PerPixel(fl.Chain):
"""(B, C, H, W) -> H*W, B, C""" """(B, C, H, W) -> H*W, B, C"""
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
fl.Permute(2, 3, 0, 1), fl.Permute(2, 3, 0, 1),
fl.Flatten(0, 1), fl.Flatten(0, 1),
@ -26,7 +26,7 @@ class PositionEmbeddingSine(fl.Module):
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
""" """
def __init__(self, num_pos_feats: int, device: Device | None = None): def __init__(self, num_pos_feats: int, device: Device | None = None) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
temperature = 10000 temperature = 10000
@ -51,7 +51,7 @@ class PositionEmbeddingSine(fl.Module):
class MultiPoolPos(fl.Module): class MultiPoolPos(fl.Module):
def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine): def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine) -> None:
super().__init__() super().__init__()
self.pool_ratios = pool_ratios self.pool_ratios = pool_ratios
self.positional_embedding = positional_embedding self.positional_embedding = positional_embedding
@ -62,7 +62,7 @@ class MultiPoolPos(fl.Module):
class Repeat(fl.Module): class Repeat(fl.Module):
def __init__(self, dim: int = 0): def __init__(self, dim: int = 0) -> None:
self.dim = dim self.dim = dim
super().__init__() super().__init__()
@ -71,7 +71,7 @@ class Repeat(fl.Module):
class _MHA_Arg(fl.Sum): class _MHA_Arg(fl.Sum):
def __init__(self, offset: int): def __init__(self, offset: int) -> None:
self.offset = offset self.offset = offset
super().__init__( super().__init__(
fl.GetArg(offset), # value fl.GetArg(offset), # value
@ -95,7 +95,7 @@ class GlobalAttention(fl.Chain):
emb_dim: int, emb_dim: int,
num_heads: int = 1, num_heads: int = 1,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Sum( fl.Sum(
fl.GetArg(0), # global fl.GetArg(0), # global
@ -125,7 +125,7 @@ class MCLM(fl.Chain):
num_heads: int = 1, num_heads: int = 1,
pool_ratios: list[int] | None = None, pool_ratios: list[int] | None = None,
device: Device | None = None, device: Device | None = None,
): ) -> None:
if pool_ratios is None: if pool_ratios is None:
pool_ratios = [2, 8, 16] pool_ratios = [2, 8, 16]

View file

@ -24,7 +24,7 @@ class TiledCrossAttention(fl.Chain):
num_heads: int = 1, num_heads: int = 1,
pool_ratios: list[int] | None = None, pool_ratios: list[int] | None = None,
device: Device | None = None, device: Device | None = None,
): ) -> None:
# Input must be a 4-tuple: (local, global) # Input must be a 4-tuple: (local, global)
if pool_ratios is None: if pool_ratios is None:
@ -70,7 +70,7 @@ class MCRM(fl.Chain):
num_heads: int = 1, num_heads: int = 1,
pool_ratios: list[int] | None = None, pool_ratios: list[int] | None = None,
device: Device | None = None, device: Device | None = None,
): ) -> None:
if pool_ratios is None: if pool_ratios is None:
pool_ratios = [1, 2, 4] pool_ratios = [1, 2, 4]

View file

@ -19,7 +19,7 @@ class CBG(fl.Chain):
in_dim: int, in_dim: int,
out_dim: int | None = None, out_dim: int | None = None,
device: Device | None = None, device: Device | None = None,
): ) -> None:
out_dim = out_dim or in_dim out_dim = out_dim or in_dim
super().__init__( super().__init__(
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device), fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
@ -36,7 +36,7 @@ class CBR(fl.Chain):
in_dim: int, in_dim: int,
out_dim: int | None = None, out_dim: int | None = None,
device: Device | None = None, device: Device | None = None,
): ) -> None:
out_dim = out_dim or in_dim out_dim = out_dim or in_dim
super().__init__( super().__init__(
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device), fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
@ -57,7 +57,7 @@ class SplitMultiView(fl.Chain):
multi_view (b, 5, c, H/2, W/2) multi_view (b, 5, c, H/2, W/2)
""" """
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
fl.Concatenate( fl.Concatenate(
PatchSplit(), # global features PatchSplit(), # global features
@ -88,7 +88,7 @@ class ShallowUpscaler(fl.Chain):
self, self,
embedding_dim: int = 128, embedding_dim: int = 128,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Sum( fl.Sum(
fl.Identity(), fl.Identity(),
@ -117,7 +117,7 @@ class PyramidL5(fl.Chain):
self, self,
embedding_dim: int = 128, embedding_dim: int = 128,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.GetArg(0), # output5 fl.GetArg(0), # output5
fl.Flatten(0, 1), fl.Flatten(0, 1),
@ -134,7 +134,7 @@ class PyramidL4(fl.Chain):
self, self,
embedding_dim: int = 128, embedding_dim: int = 128,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Sum( fl.Sum(
PyramidL5(embedding_dim=embedding_dim, device=device), PyramidL5(embedding_dim=embedding_dim, device=device),
@ -157,7 +157,7 @@ class PyramidL3(fl.Chain):
self, self,
embedding_dim: int = 128, embedding_dim: int = 128,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Sum( fl.Sum(
PyramidL4(embedding_dim=embedding_dim, device=device), PyramidL4(embedding_dim=embedding_dim, device=device),
@ -180,7 +180,7 @@ class PyramidL2(fl.Chain):
self, self,
embedding_dim: int = 128, embedding_dim: int = 128,
device: Device | None = None, device: Device | None = None,
): ) -> None:
embedding_dim = 128 embedding_dim = 128
super().__init__( super().__init__(
fl.Sum( fl.Sum(
@ -219,7 +219,7 @@ class Pyramid(fl.Chain):
self, self,
embedding_dim: int = 128, embedding_dim: int = 128,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Sum( fl.Sum(
PyramidL2(embedding_dim=embedding_dim, device=device), PyramidL2(embedding_dim=embedding_dim, device=device),
@ -253,7 +253,7 @@ class RearrangeMultiView(fl.Chain):
self, self,
embedding_dim: int = 128, embedding_dim: int = 128,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Sum( fl.Sum(
fl.Chain( # local features fl.Chain( # local features
@ -279,7 +279,7 @@ class ComputeShallow(fl.Passthrough):
self, self,
embedding_dim: int = 128, embedding_dim: int = 128,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device), fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device),
fl.SetContext("mvanet", "shallow"), fl.SetContext("mvanet", "shallow"),
@ -309,7 +309,7 @@ class MVANet(fl.Chain):
num_heads: list[int] | None = None, num_heads: list[int] | None = None,
window_size: int = 12, window_size: int = 12,
device: Device | None = None, device: Device | None = None,
): ) -> None:
if depths is None: if depths is None:
depths = [2, 2, 18, 2] depths = [2, 2, 18, 2]
if num_heads is None: if num_heads is None:

View file

@ -19,7 +19,7 @@ class Unflatten(fl.Module):
class Interpolate(fl.Module): class Interpolate(fl.Module):
def __init__(self, size: tuple[int, ...], mode: str = "bilinear"): def __init__(self, size: tuple[int, ...], mode: str = "bilinear") -> None:
super().__init__() super().__init__()
self.size = Size(size) self.size = Size(size)
self.mode = mode self.mode = mode
@ -29,7 +29,7 @@ class Interpolate(fl.Module):
class Rescale(fl.Module): class Rescale(fl.Module):
def __init__(self, scale_factor: float, mode: str = "nearest"): def __init__(self, scale_factor: float, mode: str = "nearest") -> None:
super().__init__() super().__init__()
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.mode = mode self.mode = mode
@ -39,19 +39,19 @@ class Rescale(fl.Module):
class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule): class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule):
def __init__(self, num_features: int, device: torch.device | None = None): def __init__(self, num_features: int, device: torch.device | None = None) -> None:
super().__init__(num_features=num_features, device=device) # type: ignore super().__init__(num_features=num_features, device=device) # type: ignore
class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation): class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation):
def __init__(self, device: torch.device | None = None): def __init__(self, device: torch.device | None = None) -> None:
super().__init__(device=device) # type: ignore super().__init__(device=device) # type: ignore
class PatchSplit(fl.Chain): class PatchSplit(fl.Chain):
"""(B, N, H, W) -> B, 4, N, H/2, W/2""" """(B, N, H, W) -> B, 4, N, H/2, W/2"""
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
Unflatten(-2, (2, -1)), Unflatten(-2, (2, -1)),
Unflatten(-1, (2, -1)), Unflatten(-1, (2, -1)),
@ -63,7 +63,7 @@ class PatchSplit(fl.Chain):
class PatchMerge(fl.Chain): class PatchMerge(fl.Chain):
"""B, 4, N, H, W -> (B, N, 2*H, 2*W)""" """B, 4, N, H, W -> (B, N, 2*H, 2*W)"""
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
Unflatten(1, (2, 2)), Unflatten(1, (2, 2)),
fl.Permute(0, 3, 1, 4, 2, 5), fl.Permute(0, 3, 1, 4, 2, 5),
@ -82,7 +82,7 @@ class FeedForward(fl.Residual):
class _GetArgs(fl.Parallel): class _GetArgs(fl.Parallel):
def __init__(self, n: int): def __init__(self, n: int) -> None:
super().__init__( super().__init__(
fl.Chain( fl.Chain(
fl.GetArg(0), fl.GetArg(0),
@ -103,7 +103,7 @@ class _GetArgs(fl.Parallel):
class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule): class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule):
def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None): def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None) -> None:
super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore
@property @property
@ -122,7 +122,7 @@ class PatchwiseCrossAttention(fl.Chain):
d_model: int, d_model: int,
num_heads: int, num_heads: int,
device: torch.device | None = None, device: torch.device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Concatenate( fl.Concatenate(
fl.Chain( fl.Chain(

View file

@ -22,7 +22,7 @@ def to_windows(x: Tensor, window_size: int) -> Tensor:
class ToWindows(fl.Module): class ToWindows(fl.Module):
def __init__(self, window_size: int): def __init__(self, window_size: int) -> None:
super().__init__() super().__init__()
self.window_size = window_size self.window_size = window_size
@ -67,7 +67,7 @@ def get_attn_mask(H: int, window_size: int, device: Device | None = None) -> Ten
class Pad(fl.Module): class Pad(fl.Module):
def __init__(self, step: int): def __init__(self, step: int) -> None:
super().__init__() super().__init__()
self.step = step self.step = step
@ -135,7 +135,7 @@ class WindowUnflatten(fl.Module):
class Roll(fl.Module): class Roll(fl.Module):
def __init__(self, *shifts: tuple[int, int]): def __init__(self, *shifts: tuple[int, int]) -> None:
super().__init__() super().__init__()
self.shifts = shifts self.shifts = shifts
self._dims = tuple(s[0] for s in shifts) self._dims = tuple(s[0] for s in shifts)
@ -148,7 +148,7 @@ class Roll(fl.Module):
class RelativePositionBias(fl.Module): class RelativePositionBias(fl.Module):
relative_position_index: Tensor relative_position_index: Tensor
def __init__(self, window_size: int, num_heads: int, device: Device | None = None): def __init__(self, window_size: int, num_heads: int, device: Device | None = None) -> None:
super().__init__() super().__init__()
self.relative_position_bias_table = torch.nn.Parameter( self.relative_position_bias_table = torch.nn.Parameter(
torch.empty( torch.empty(
@ -178,7 +178,7 @@ class WindowSDPA(fl.Module):
num_heads: int, num_heads: int,
shift: bool = False, shift: bool = False,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__() super().__init__()
self.window_size = window_size self.window_size = window_size
self.num_heads = num_heads self.num_heads = num_heads
@ -220,7 +220,7 @@ class WindowAttention(fl.Chain):
num_heads: int, num_heads: int,
shift: bool = False, shift: bool = False,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Linear(dim, dim * 3, bias=True, device=device), fl.Linear(dim, dim * 3, bias=True, device=device),
WindowSDPA(dim, window_size, num_heads, shift, device=device), WindowSDPA(dim, window_size, num_heads, shift, device=device),
@ -237,7 +237,7 @@ class SwinTransformerBlock(fl.Chain):
shift_size: int = 0, shift_size: int = 0,
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
device: Device | None = None, device: Device | None = None,
): ) -> None:
assert 0 <= shift_size < window_size, "shift_size must in [0, window_size[" assert 0 <= shift_size < window_size, "shift_size must in [0, window_size["
super().__init__( super().__init__(
@ -272,7 +272,7 @@ class SwinTransformerBlock(fl.Chain):
class PatchMerging(fl.Chain): class PatchMerging(fl.Chain):
def __init__(self, dim: int, device: Device | None = None): def __init__(self, dim: int, device: Device | None = None) -> None:
super().__init__( super().__init__(
SquareUnflatten(1), SquareUnflatten(1),
Pad(2), Pad(2),
@ -295,7 +295,7 @@ class BasicLayer(fl.Chain):
window_size: int = 7, window_size: int = 7,
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
SwinTransformerBlock( SwinTransformerBlock(
dim=dim, dim=dim,
@ -316,7 +316,7 @@ class PatchEmbedding(fl.Chain):
in_chans: int = 3, in_chans: int = 3,
embedding_dim: int = 96, embedding_dim: int = 96,
device: Device | None = None, device: Device | None = None,
): ) -> None:
super().__init__( super().__init__(
fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device), fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device),
fl.Flatten(2), fl.Flatten(2),
@ -341,7 +341,7 @@ class SwinTransformer(fl.Chain):
window_size: int = 7, # image size is 32 * this window_size: int = 7, # image size is 32 * this
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
device: Device | None = None, device: Device | None = None,
): ) -> None:
if depths is None: if depths is None:
depths = [2, 2, 6, 2] depths = [2, 2, 6, 2]