diff --git a/src/refiners/fluxion/layers/__init__.py b/src/refiners/fluxion/layers/__init__.py
index 4b55c64..dddb809 100644
--- a/src/refiners/fluxion/layers/__init__.py
+++ b/src/refiners/fluxion/layers/__init__.py
@@ -7,11 +7,17 @@ from refiners.fluxion.layers.basics import (
     Flatten,
     Unflatten,
     Transpose,
+    GetArg,
     Permute,
     Reshape,
     Squeeze,
     Unsqueeze,
     Slicing,
+    Sin,
+    Cos,
+    Chunk,
+    Multiply,
+    Unbind,
     Parameter,
     Buffer,
 )
@@ -28,6 +34,7 @@ from refiners.fluxion.layers.chain import (
     Passthrough,
     Breakpoint,
     Concatenate,
+    Matmul,
 )
 from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
 from refiners.fluxion.layers.linear import Linear, MultiLinear
@@ -53,6 +60,7 @@ __all__ = [
     "SelfAttention",
     "SelfAttention2d",
     "Identity",
+    "GetArg",
     "View",
     "Flatten",
     "Unflatten",
@@ -63,6 +71,12 @@ __all__ = [
     "Reshape",
     "Slicing",
     "Parameter",
+    "Sin",
+    "Cos",
+    "Chunk",
+    "Multiply",
+    "Unbind",
+    "Matmul",
     "Buffer",
     "Lambda",
     "Return",
diff --git a/src/refiners/fluxion/layers/basics.py b/src/refiners/fluxion/layers/basics.py
index 03b48b8..9e6ff1f 100644
--- a/src/refiners/fluxion/layers/basics.py
+++ b/src/refiners/fluxion/layers/basics.py
@@ -1,4 +1,5 @@
 from refiners.fluxion.layers.module import Module, WeightedModule
+import torch
 from torch import randn, Tensor, Size, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
 
@@ -20,6 +21,15 @@ class View(Module):
         return x.view(*self.shape)
 
 
+class GetArg(Module):
+    def __init__(self, index: int) -> None:
+        super().__init__()
+        self.index = index
+
+    def forward(self, *args: Tensor) -> Tensor:
+        return args[self.index]
+
+
 class Flatten(Module):
     def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
         super().__init__()
@@ -101,6 +111,45 @@ class Unsqueeze(Module):
         return x.unsqueeze(self.dim)
 
 
+class Unbind(Module):
+    def __init__(self, dim: int = 0) -> None:
+        self.dim = dim
+        super().__init__()
+
+    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
+        return x.unbind(dim=self.dim)  # type: ignore
+
+
+class Chunk(Module):
+    def __init__(self, chunks: int, dim: int = 0) -> None:
+        self.chunks = chunks
+        self.dim = dim
+        super().__init__()
+
+    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
+        return x.chunk(chunks=self.chunks, dim=self.dim)  # type: ignore
+
+
+class Sin(Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.sin(input=x)
+
+
+class Cos(Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.cos(input=x)
+
+
+class Multiply(Module):
+    def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None:
+        super().__init__()
+        self.scale = scale
+        self.bias = bias
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.scale * x + self.bias
+
+
 class Parameter(WeightedModule):
     """
     A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py
index a6081f9..1a36747 100644
--- a/src/refiners/fluxion/layers/chain.py
+++ b/src/refiners/fluxion/layers/chain.py
@@ -1,5 +1,6 @@
 import inspect
 from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
+import torch
 from torch import Tensor, cat, device as Device, dtype as DType
 from refiners.fluxion.layers.basics import Identity
 from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
@@ -483,3 +484,16 @@ class Concatenate(Chain):
 
     def _show_only_tag(self) -> bool:
         return self.__class__ == Concatenate
+
+
+class Matmul(Chain):
+    _tag = "MATMUL"
+
+    def __init__(self, input: Module, other: Module) -> None:
+        super().__init__(
+            input,
+            other,
+        )
+
+    def forward(self, *args: Tensor) -> Tensor:
+        return torch.matmul(input=self[0](*args), other=self[1](*args))
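For reference, a minimal usage sketch of the new layers. The import path and class signatures come from the diff above; the call behavior is assumed to follow the existing refiners Module/Chain semantics (modules are torch.nn.Module subclasses invoked directly on tensors), so this is illustrative rather than a definitive API example.

```python
import torch
import refiners.fluxion.layers as fl

# Matmul(input, other) runs both sub-modules on the same arguments and
# matrix-multiplies their outputs: here sin(x) @ cos(x) for batched square matrices.
model = fl.Matmul(fl.Sin(), fl.Cos())
x = torch.randn(2, 4, 4)
assert model(x).shape == (2, 4, 4)

# Multiply applies scale * x + bias (an affine map, despite the name).
scaled = fl.Multiply(scale=2.0, bias=1.0)(x)

# Chunk and Unbind return tuples of tensors split along a dimension.
halves = fl.Chunk(chunks=2, dim=-1)(x)  # two tensors of shape (2, 4, 2)
rows = fl.Unbind(dim=0)(x)              # two tensors of shape (4, 4)

# GetArg selects one positional argument, e.g. the second of two inputs.
second = fl.GetArg(index=1)(x, scaled)
```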