mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 14:18:46 +00:00
return typing for __init__
This commit is contained in:
parent
8aa1d9d91d
commit
5a1e1e70f8
|
@ -14,7 +14,7 @@ from .utils import FeedForward, MultiheadAttention, MultiPool, PatchMerge, Patch
|
|||
class PerPixel(fl.Chain):
|
||||
"""(B, C, H, W) -> H*W, B, C"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
fl.Permute(2, 3, 0, 1),
|
||||
fl.Flatten(0, 1),
|
||||
|
@ -26,7 +26,7 @@ class PositionEmbeddingSine(fl.Module):
|
|||
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats: int, device: Device | None = None):
|
||||
def __init__(self, num_pos_feats: int, device: Device | None = None) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
temperature = 10000
|
||||
|
@ -51,7 +51,7 @@ class PositionEmbeddingSine(fl.Module):
|
|||
|
||||
|
||||
class MultiPoolPos(fl.Module):
|
||||
def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine):
|
||||
def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine) -> None:
|
||||
super().__init__()
|
||||
self.pool_ratios = pool_ratios
|
||||
self.positional_embedding = positional_embedding
|
||||
|
@ -62,7 +62,7 @@ class MultiPoolPos(fl.Module):
|
|||
|
||||
|
||||
class Repeat(fl.Module):
|
||||
def __init__(self, dim: int = 0):
|
||||
def __init__(self, dim: int = 0) -> None:
|
||||
self.dim = dim
|
||||
super().__init__()
|
||||
|
||||
|
@ -71,7 +71,7 @@ class Repeat(fl.Module):
|
|||
|
||||
|
||||
class _MHA_Arg(fl.Sum):
|
||||
def __init__(self, offset: int):
|
||||
def __init__(self, offset: int) -> None:
|
||||
self.offset = offset
|
||||
super().__init__(
|
||||
fl.GetArg(offset), # value
|
||||
|
@ -95,7 +95,7 @@ class GlobalAttention(fl.Chain):
|
|||
emb_dim: int,
|
||||
num_heads: int = 1,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
fl.GetArg(0), # global
|
||||
|
@ -125,7 +125,7 @@ class MCLM(fl.Chain):
|
|||
num_heads: int = 1,
|
||||
pool_ratios: list[int] | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
if pool_ratios is None:
|
||||
pool_ratios = [2, 8, 16]
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ class TiledCrossAttention(fl.Chain):
|
|||
num_heads: int = 1,
|
||||
pool_ratios: list[int] | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
# Input must be a 4-tuple: (local, global)
|
||||
|
||||
if pool_ratios is None:
|
||||
|
@ -70,7 +70,7 @@ class MCRM(fl.Chain):
|
|||
num_heads: int = 1,
|
||||
pool_ratios: list[int] | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
if pool_ratios is None:
|
||||
pool_ratios = [1, 2, 4]
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ class CBG(fl.Chain):
|
|||
in_dim: int,
|
||||
out_dim: int | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
out_dim = out_dim or in_dim
|
||||
super().__init__(
|
||||
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
|
||||
|
@ -36,7 +36,7 @@ class CBR(fl.Chain):
|
|||
in_dim: int,
|
||||
out_dim: int | None = None,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
out_dim = out_dim or in_dim
|
||||
super().__init__(
|
||||
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
|
||||
|
@ -57,7 +57,7 @@ class SplitMultiView(fl.Chain):
|
|||
multi_view (b, 5, c, H/2, W/2)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
fl.Concatenate(
|
||||
PatchSplit(), # global features
|
||||
|
@ -88,7 +88,7 @@ class ShallowUpscaler(fl.Chain):
|
|||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
fl.Identity(),
|
||||
|
@ -117,7 +117,7 @@ class PyramidL5(fl.Chain):
|
|||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.GetArg(0), # output5
|
||||
fl.Flatten(0, 1),
|
||||
|
@ -134,7 +134,7 @@ class PyramidL4(fl.Chain):
|
|||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
PyramidL5(embedding_dim=embedding_dim, device=device),
|
||||
|
@ -157,7 +157,7 @@ class PyramidL3(fl.Chain):
|
|||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
PyramidL4(embedding_dim=embedding_dim, device=device),
|
||||
|
@ -180,7 +180,7 @@ class PyramidL2(fl.Chain):
|
|||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
embedding_dim = 128
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
|
@ -219,7 +219,7 @@ class Pyramid(fl.Chain):
|
|||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
PyramidL2(embedding_dim=embedding_dim, device=device),
|
||||
|
@ -253,7 +253,7 @@ class RearrangeMultiView(fl.Chain):
|
|||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Sum(
|
||||
fl.Chain( # local features
|
||||
|
@ -279,7 +279,7 @@ class ComputeShallow(fl.Passthrough):
|
|||
self,
|
||||
embedding_dim: int = 128,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device),
|
||||
fl.SetContext("mvanet", "shallow"),
|
||||
|
@ -309,7 +309,7 @@ class MVANet(fl.Chain):
|
|||
num_heads: list[int] | None = None,
|
||||
window_size: int = 12,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
if depths is None:
|
||||
depths = [2, 2, 18, 2]
|
||||
if num_heads is None:
|
||||
|
|
|
@ -19,7 +19,7 @@ class Unflatten(fl.Module):
|
|||
|
||||
|
||||
class Interpolate(fl.Module):
|
||||
def __init__(self, size: tuple[int, ...], mode: str = "bilinear"):
|
||||
def __init__(self, size: tuple[int, ...], mode: str = "bilinear") -> None:
|
||||
super().__init__()
|
||||
self.size = Size(size)
|
||||
self.mode = mode
|
||||
|
@ -29,7 +29,7 @@ class Interpolate(fl.Module):
|
|||
|
||||
|
||||
class Rescale(fl.Module):
|
||||
def __init__(self, scale_factor: float, mode: str = "nearest"):
|
||||
def __init__(self, scale_factor: float, mode: str = "nearest") -> None:
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
|
@ -39,19 +39,19 @@ class Rescale(fl.Module):
|
|||
|
||||
|
||||
class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule):
|
||||
def __init__(self, num_features: int, device: torch.device | None = None):
|
||||
def __init__(self, num_features: int, device: torch.device | None = None) -> None:
|
||||
super().__init__(num_features=num_features, device=device) # type: ignore
|
||||
|
||||
|
||||
class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation):
|
||||
def __init__(self, device: torch.device | None = None):
|
||||
def __init__(self, device: torch.device | None = None) -> None:
|
||||
super().__init__(device=device) # type: ignore
|
||||
|
||||
|
||||
class PatchSplit(fl.Chain):
|
||||
"""(B, N, H, W) -> B, 4, N, H/2, W/2"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
Unflatten(-2, (2, -1)),
|
||||
Unflatten(-1, (2, -1)),
|
||||
|
@ -63,7 +63,7 @@ class PatchSplit(fl.Chain):
|
|||
class PatchMerge(fl.Chain):
|
||||
"""B, 4, N, H, W -> (B, N, 2*H, 2*W)"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
Unflatten(1, (2, 2)),
|
||||
fl.Permute(0, 3, 1, 4, 2, 5),
|
||||
|
@ -82,7 +82,7 @@ class FeedForward(fl.Residual):
|
|||
|
||||
|
||||
class _GetArgs(fl.Parallel):
|
||||
def __init__(self, n: int):
|
||||
def __init__(self, n: int) -> None:
|
||||
super().__init__(
|
||||
fl.Chain(
|
||||
fl.GetArg(0),
|
||||
|
@ -103,7 +103,7 @@ class _GetArgs(fl.Parallel):
|
|||
|
||||
|
||||
class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule):
|
||||
def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None):
|
||||
def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None) -> None:
|
||||
super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore
|
||||
|
||||
@property
|
||||
|
@ -122,7 +122,7 @@ class PatchwiseCrossAttention(fl.Chain):
|
|||
d_model: int,
|
||||
num_heads: int,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Concatenate(
|
||||
fl.Chain(
|
||||
|
|
|
@ -22,7 +22,7 @@ def to_windows(x: Tensor, window_size: int) -> Tensor:
|
|||
|
||||
|
||||
class ToWindows(fl.Module):
|
||||
def __init__(self, window_size: int):
|
||||
def __init__(self, window_size: int) -> None:
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
|
||||
|
@ -67,7 +67,7 @@ def get_attn_mask(H: int, window_size: int, device: Device | None = None) -> Ten
|
|||
|
||||
|
||||
class Pad(fl.Module):
|
||||
def __init__(self, step: int):
|
||||
def __init__(self, step: int) -> None:
|
||||
super().__init__()
|
||||
self.step = step
|
||||
|
||||
|
@ -135,7 +135,7 @@ class WindowUnflatten(fl.Module):
|
|||
|
||||
|
||||
class Roll(fl.Module):
|
||||
def __init__(self, *shifts: tuple[int, int]):
|
||||
def __init__(self, *shifts: tuple[int, int]) -> None:
|
||||
super().__init__()
|
||||
self.shifts = shifts
|
||||
self._dims = tuple(s[0] for s in shifts)
|
||||
|
@ -148,7 +148,7 @@ class Roll(fl.Module):
|
|||
class RelativePositionBias(fl.Module):
|
||||
relative_position_index: Tensor
|
||||
|
||||
def __init__(self, window_size: int, num_heads: int, device: Device | None = None):
|
||||
def __init__(self, window_size: int, num_heads: int, device: Device | None = None) -> None:
|
||||
super().__init__()
|
||||
self.relative_position_bias_table = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
|
@ -178,7 +178,7 @@ class WindowSDPA(fl.Module):
|
|||
num_heads: int,
|
||||
shift: bool = False,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_heads = num_heads
|
||||
|
@ -220,7 +220,7 @@ class WindowAttention(fl.Chain):
|
|||
num_heads: int,
|
||||
shift: bool = False,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Linear(dim, dim * 3, bias=True, device=device),
|
||||
WindowSDPA(dim, window_size, num_heads, shift, device=device),
|
||||
|
@ -237,7 +237,7 @@ class SwinTransformerBlock(fl.Chain):
|
|||
shift_size: int = 0,
|
||||
mlp_ratio: float = 4.0,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
assert 0 <= shift_size < window_size, "shift_size must in [0, window_size["
|
||||
|
||||
super().__init__(
|
||||
|
@ -272,7 +272,7 @@ class SwinTransformerBlock(fl.Chain):
|
|||
|
||||
|
||||
class PatchMerging(fl.Chain):
|
||||
def __init__(self, dim: int, device: Device | None = None):
|
||||
def __init__(self, dim: int, device: Device | None = None) -> None:
|
||||
super().__init__(
|
||||
SquareUnflatten(1),
|
||||
Pad(2),
|
||||
|
@ -295,7 +295,7 @@ class BasicLayer(fl.Chain):
|
|||
window_size: int = 7,
|
||||
mlp_ratio: float = 4.0,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
SwinTransformerBlock(
|
||||
dim=dim,
|
||||
|
@ -316,7 +316,7 @@ class PatchEmbedding(fl.Chain):
|
|||
in_chans: int = 3,
|
||||
embedding_dim: int = 96,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device),
|
||||
fl.Flatten(2),
|
||||
|
@ -341,7 +341,7 @@ class SwinTransformer(fl.Chain):
|
|||
window_size: int = 7, # image size is 32 * this
|
||||
mlp_ratio: float = 4.0,
|
||||
device: Device | None = None,
|
||||
):
|
||||
) -> None:
|
||||
if depths is None:
|
||||
depths = [2, 2, 6, 2]
|
||||
|
||||
|
|
Loading…
Reference in a new issue