mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
return typing for __init__
This commit is contained in:
parent
8aa1d9d91d
commit
0046d2288f
|
@ -14,7 +14,7 @@ from .utils import FeedForward, MultiheadAttention, MultiPool, PatchMerge, Patch
|
||||||
class PerPixel(fl.Chain):
|
class PerPixel(fl.Chain):
|
||||||
"""(B, C, H, W) -> H*W, B, C"""
|
"""(B, C, H, W) -> H*W, B, C"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Permute(2, 3, 0, 1),
|
fl.Permute(2, 3, 0, 1),
|
||||||
fl.Flatten(0, 1),
|
fl.Flatten(0, 1),
|
||||||
|
@ -26,7 +26,7 @@ class PositionEmbeddingSine(fl.Module):
|
||||||
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
|
Non-trainable position embedding, originally from https://github.com/facebookresearch/detr
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_pos_feats: int, device: Device | None = None):
|
def __init__(self, num_pos_feats: int, device: Device | None = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
temperature = 10000
|
temperature = 10000
|
||||||
|
@ -51,7 +51,7 @@ class PositionEmbeddingSine(fl.Module):
|
||||||
|
|
||||||
|
|
||||||
class MultiPoolPos(fl.Module):
|
class MultiPoolPos(fl.Module):
|
||||||
def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine):
|
def __init__(self, pool_ratios: list[int], positional_embedding: PositionEmbeddingSine) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pool_ratios = pool_ratios
|
self.pool_ratios = pool_ratios
|
||||||
self.positional_embedding = positional_embedding
|
self.positional_embedding = positional_embedding
|
||||||
|
@ -62,7 +62,7 @@ class MultiPoolPos(fl.Module):
|
||||||
|
|
||||||
|
|
||||||
class Repeat(fl.Module):
|
class Repeat(fl.Module):
|
||||||
def __init__(self, dim: int = 0):
|
def __init__(self, dim: int = 0) -> None:
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ class Repeat(fl.Module):
|
||||||
|
|
||||||
|
|
||||||
class _MHA_Arg(fl.Sum):
|
class _MHA_Arg(fl.Sum):
|
||||||
def __init__(self, offset: int):
|
def __init__(self, offset: int) -> None:
|
||||||
self.offset = offset
|
self.offset = offset
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.GetArg(offset), # value
|
fl.GetArg(offset), # value
|
||||||
|
@ -95,7 +95,7 @@ class GlobalAttention(fl.Chain):
|
||||||
emb_dim: int,
|
emb_dim: int,
|
||||||
num_heads: int = 1,
|
num_heads: int = 1,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
fl.GetArg(0), # global
|
fl.GetArg(0), # global
|
||||||
|
@ -125,7 +125,7 @@ class MCLM(fl.Chain):
|
||||||
num_heads: int = 1,
|
num_heads: int = 1,
|
||||||
pool_ratios: list[int] | None = None,
|
pool_ratios: list[int] | None = None,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
if pool_ratios is None:
|
if pool_ratios is None:
|
||||||
pool_ratios = [2, 8, 16]
|
pool_ratios = [2, 8, 16]
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ class TiledCrossAttention(fl.Chain):
|
||||||
num_heads: int = 1,
|
num_heads: int = 1,
|
||||||
pool_ratios: list[int] | None = None,
|
pool_ratios: list[int] | None = None,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
# Input must be a 4-tuple: (local, global)
|
# Input must be a 4-tuple: (local, global)
|
||||||
|
|
||||||
if pool_ratios is None:
|
if pool_ratios is None:
|
||||||
|
@ -70,7 +70,7 @@ class MCRM(fl.Chain):
|
||||||
num_heads: int = 1,
|
num_heads: int = 1,
|
||||||
pool_ratios: list[int] | None = None,
|
pool_ratios: list[int] | None = None,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
if pool_ratios is None:
|
if pool_ratios is None:
|
||||||
pool_ratios = [1, 2, 4]
|
pool_ratios = [1, 2, 4]
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ class CBG(fl.Chain):
|
||||||
in_dim: int,
|
in_dim: int,
|
||||||
out_dim: int | None = None,
|
out_dim: int | None = None,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
out_dim = out_dim or in_dim
|
out_dim = out_dim or in_dim
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
|
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
|
||||||
|
@ -36,7 +36,7 @@ class CBR(fl.Chain):
|
||||||
in_dim: int,
|
in_dim: int,
|
||||||
out_dim: int | None = None,
|
out_dim: int | None = None,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
out_dim = out_dim or in_dim
|
out_dim = out_dim or in_dim
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
|
fl.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, device=device),
|
||||||
|
@ -57,7 +57,7 @@ class SplitMultiView(fl.Chain):
|
||||||
multi_view (b, 5, c, H/2, W/2)
|
multi_view (b, 5, c, H/2, W/2)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Concatenate(
|
fl.Concatenate(
|
||||||
PatchSplit(), # global features
|
PatchSplit(), # global features
|
||||||
|
@ -88,7 +88,7 @@ class ShallowUpscaler(fl.Chain):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 128,
|
embedding_dim: int = 128,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
fl.Identity(),
|
fl.Identity(),
|
||||||
|
@ -117,7 +117,7 @@ class PyramidL5(fl.Chain):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 128,
|
embedding_dim: int = 128,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.GetArg(0), # output5
|
fl.GetArg(0), # output5
|
||||||
fl.Flatten(0, 1),
|
fl.Flatten(0, 1),
|
||||||
|
@ -134,7 +134,7 @@ class PyramidL4(fl.Chain):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 128,
|
embedding_dim: int = 128,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
PyramidL5(embedding_dim=embedding_dim, device=device),
|
PyramidL5(embedding_dim=embedding_dim, device=device),
|
||||||
|
@ -157,7 +157,7 @@ class PyramidL3(fl.Chain):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 128,
|
embedding_dim: int = 128,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
PyramidL4(embedding_dim=embedding_dim, device=device),
|
PyramidL4(embedding_dim=embedding_dim, device=device),
|
||||||
|
@ -180,7 +180,7 @@ class PyramidL2(fl.Chain):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 128,
|
embedding_dim: int = 128,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
embedding_dim = 128
|
embedding_dim = 128
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
|
@ -219,7 +219,7 @@ class Pyramid(fl.Chain):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 128,
|
embedding_dim: int = 128,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
PyramidL2(embedding_dim=embedding_dim, device=device),
|
PyramidL2(embedding_dim=embedding_dim, device=device),
|
||||||
|
@ -253,7 +253,7 @@ class RearrangeMultiView(fl.Chain):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 128,
|
embedding_dim: int = 128,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
fl.Chain( # local features
|
fl.Chain( # local features
|
||||||
|
@ -279,7 +279,7 @@ class ComputeShallow(fl.Passthrough):
|
||||||
self,
|
self,
|
||||||
embedding_dim: int = 128,
|
embedding_dim: int = 128,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device),
|
fl.Conv2d(3, embedding_dim, kernel_size=3, padding=1, device=device),
|
||||||
fl.SetContext("mvanet", "shallow"),
|
fl.SetContext("mvanet", "shallow"),
|
||||||
|
@ -309,7 +309,7 @@ class MVANet(fl.Chain):
|
||||||
num_heads: list[int] | None = None,
|
num_heads: list[int] | None = None,
|
||||||
window_size: int = 12,
|
window_size: int = 12,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
if depths is None:
|
if depths is None:
|
||||||
depths = [2, 2, 18, 2]
|
depths = [2, 2, 18, 2]
|
||||||
if num_heads is None:
|
if num_heads is None:
|
||||||
|
|
|
@ -19,7 +19,7 @@ class Unflatten(fl.Module):
|
||||||
|
|
||||||
|
|
||||||
class Interpolate(fl.Module):
|
class Interpolate(fl.Module):
|
||||||
def __init__(self, size: tuple[int, ...], mode: str = "bilinear"):
|
def __init__(self, size: tuple[int, ...], mode: str = "bilinear") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.size = Size(size)
|
self.size = Size(size)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
@ -29,7 +29,7 @@ class Interpolate(fl.Module):
|
||||||
|
|
||||||
|
|
||||||
class Rescale(fl.Module):
|
class Rescale(fl.Module):
|
||||||
def __init__(self, scale_factor: float, mode: str = "nearest"):
|
def __init__(self, scale_factor: float, mode: str = "nearest") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
@ -39,19 +39,19 @@ class Rescale(fl.Module):
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule):
|
class BatchNorm2d(torch.nn.BatchNorm2d, fl.WeightedModule):
|
||||||
def __init__(self, num_features: int, device: torch.device | None = None):
|
def __init__(self, num_features: int, device: torch.device | None = None) -> None:
|
||||||
super().__init__(num_features=num_features, device=device) # type: ignore
|
super().__init__(num_features=num_features, device=device) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation):
|
class PReLU(torch.nn.PReLU, fl.WeightedModule, fl.Activation):
|
||||||
def __init__(self, device: torch.device | None = None):
|
def __init__(self, device: torch.device | None = None) -> None:
|
||||||
super().__init__(device=device) # type: ignore
|
super().__init__(device=device) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class PatchSplit(fl.Chain):
|
class PatchSplit(fl.Chain):
|
||||||
"""(B, N, H, W) -> B, 4, N, H/2, W/2"""
|
"""(B, N, H, W) -> B, 4, N, H/2, W/2"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
Unflatten(-2, (2, -1)),
|
Unflatten(-2, (2, -1)),
|
||||||
Unflatten(-1, (2, -1)),
|
Unflatten(-1, (2, -1)),
|
||||||
|
@ -63,7 +63,7 @@ class PatchSplit(fl.Chain):
|
||||||
class PatchMerge(fl.Chain):
|
class PatchMerge(fl.Chain):
|
||||||
"""B, 4, N, H, W -> (B, N, 2*H, 2*W)"""
|
"""B, 4, N, H, W -> (B, N, 2*H, 2*W)"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
Unflatten(1, (2, 2)),
|
Unflatten(1, (2, 2)),
|
||||||
fl.Permute(0, 3, 1, 4, 2, 5),
|
fl.Permute(0, 3, 1, 4, 2, 5),
|
||||||
|
@ -82,7 +82,7 @@ class FeedForward(fl.Residual):
|
||||||
|
|
||||||
|
|
||||||
class _GetArgs(fl.Parallel):
|
class _GetArgs(fl.Parallel):
|
||||||
def __init__(self, n: int):
|
def __init__(self, n: int) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.GetArg(0),
|
fl.GetArg(0),
|
||||||
|
@ -103,7 +103,7 @@ class _GetArgs(fl.Parallel):
|
||||||
|
|
||||||
|
|
||||||
class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule):
|
class MultiheadAttention(torch.nn.MultiheadAttention, fl.WeightedModule):
|
||||||
def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None):
|
def __init__(self, embedding_dim: int, num_heads: int, device: torch.device | None = None) -> None:
|
||||||
super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore
|
super().__init__(embed_dim=embedding_dim, num_heads=num_heads, device=device) # type: ignore
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -122,7 +122,7 @@ class PatchwiseCrossAttention(fl.Chain):
|
||||||
d_model: int,
|
d_model: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
device: torch.device | None = None,
|
device: torch.device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Concatenate(
|
fl.Concatenate(
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
|
|
|
@ -22,7 +22,7 @@ def to_windows(x: Tensor, window_size: int) -> Tensor:
|
||||||
|
|
||||||
|
|
||||||
class ToWindows(fl.Module):
|
class ToWindows(fl.Module):
|
||||||
def __init__(self, window_size: int):
|
def __init__(self, window_size: int) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ def get_attn_mask(H: int, window_size: int, device: Device | None = None) -> Ten
|
||||||
|
|
||||||
|
|
||||||
class Pad(fl.Module):
|
class Pad(fl.Module):
|
||||||
def __init__(self, step: int):
|
def __init__(self, step: int) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.step = step
|
self.step = step
|
||||||
|
|
||||||
|
@ -135,7 +135,7 @@ class WindowUnflatten(fl.Module):
|
||||||
|
|
||||||
|
|
||||||
class Roll(fl.Module):
|
class Roll(fl.Module):
|
||||||
def __init__(self, *shifts: tuple[int, int]):
|
def __init__(self, *shifts: tuple[int, int]) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shifts = shifts
|
self.shifts = shifts
|
||||||
self._dims = tuple(s[0] for s in shifts)
|
self._dims = tuple(s[0] for s in shifts)
|
||||||
|
@ -148,7 +148,7 @@ class Roll(fl.Module):
|
||||||
class RelativePositionBias(fl.Module):
|
class RelativePositionBias(fl.Module):
|
||||||
relative_position_index: Tensor
|
relative_position_index: Tensor
|
||||||
|
|
||||||
def __init__(self, window_size: int, num_heads: int, device: Device | None = None):
|
def __init__(self, window_size: int, num_heads: int, device: Device | None = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.relative_position_bias_table = torch.nn.Parameter(
|
self.relative_position_bias_table = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
|
@ -178,7 +178,7 @@ class WindowSDPA(fl.Module):
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
shift: bool = False,
|
shift: bool = False,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
@ -220,7 +220,7 @@ class WindowAttention(fl.Chain):
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
shift: bool = False,
|
shift: bool = False,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Linear(dim, dim * 3, bias=True, device=device),
|
fl.Linear(dim, dim * 3, bias=True, device=device),
|
||||||
WindowSDPA(dim, window_size, num_heads, shift, device=device),
|
WindowSDPA(dim, window_size, num_heads, shift, device=device),
|
||||||
|
@ -237,7 +237,7 @@ class SwinTransformerBlock(fl.Chain):
|
||||||
shift_size: int = 0,
|
shift_size: int = 0,
|
||||||
mlp_ratio: float = 4.0,
|
mlp_ratio: float = 4.0,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
assert 0 <= shift_size < window_size, "shift_size must in [0, window_size["
|
assert 0 <= shift_size < window_size, "shift_size must in [0, window_size["
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -272,7 +272,7 @@ class SwinTransformerBlock(fl.Chain):
|
||||||
|
|
||||||
|
|
||||||
class PatchMerging(fl.Chain):
|
class PatchMerging(fl.Chain):
|
||||||
def __init__(self, dim: int, device: Device | None = None):
|
def __init__(self, dim: int, device: Device | None = None) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
SquareUnflatten(1),
|
SquareUnflatten(1),
|
||||||
Pad(2),
|
Pad(2),
|
||||||
|
@ -295,7 +295,7 @@ class BasicLayer(fl.Chain):
|
||||||
window_size: int = 7,
|
window_size: int = 7,
|
||||||
mlp_ratio: float = 4.0,
|
mlp_ratio: float = 4.0,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
SwinTransformerBlock(
|
SwinTransformerBlock(
|
||||||
dim=dim,
|
dim=dim,
|
||||||
|
@ -316,7 +316,7 @@ class PatchEmbedding(fl.Chain):
|
||||||
in_chans: int = 3,
|
in_chans: int = 3,
|
||||||
embedding_dim: int = 96,
|
embedding_dim: int = 96,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device),
|
fl.Conv2d(in_chans, embedding_dim, kernel_size=patch_size, stride=patch_size, device=device),
|
||||||
fl.Flatten(2),
|
fl.Flatten(2),
|
||||||
|
@ -341,7 +341,7 @@ class SwinTransformer(fl.Chain):
|
||||||
window_size: int = 7, # image size is 32 * this
|
window_size: int = 7, # image size is 32 * this
|
||||||
mlp_ratio: float = 4.0,
|
mlp_ratio: float = 4.0,
|
||||||
device: Device | None = None,
|
device: Device | None = None,
|
||||||
):
|
) -> None:
|
||||||
if depths is None:
|
if depths is None:
|
||||||
depths = [2, 2, 6, 2]
|
depths = [2, 2, 6, 2]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue