add new basic layers and Matmul chain

Benjamin Trom 2023-09-12 10:51:50 +02:00
parent 2f2510a9b1
commit b515c02867
3 changed files with 77 additions and 0 deletions

src/refiners/fluxion/layers/__init__.py

@@ -7,11 +7,17 @@ from refiners.fluxion.layers.basics import (
     Flatten,
     Unflatten,
     Transpose,
+    GetArg,
     Permute,
     Reshape,
     Squeeze,
     Unsqueeze,
     Slicing,
+    Sin,
+    Cos,
+    Chunk,
+    Multiply,
+    Unbind,
     Parameter,
     Buffer,
 )
@@ -28,6 +34,7 @@ from refiners.fluxion.layers.chain import (
     Passthrough,
     Breakpoint,
     Concatenate,
+    Matmul,
 )
 from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
 from refiners.fluxion.layers.linear import Linear, MultiLinear
@@ -53,6 +60,7 @@ __all__ = [
     "SelfAttention",
     "SelfAttention2d",
     "Identity",
+    "GetArg",
     "View",
     "Flatten",
     "Unflatten",
@@ -63,6 +71,12 @@ __all__ = [
     "Reshape",
     "Slicing",
     "Parameter",
+    "Sin",
+    "Cos",
+    "Chunk",
+    "Multiply",
+    "Unbind",
+    "Matmul",
     "Buffer",
     "Lambda",
     "Return",

src/refiners/fluxion/layers/basics.py

@@ -1,4 +1,5 @@
 from refiners.fluxion.layers.module import Module, WeightedModule
+import torch
 from torch import randn, Tensor, Size, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
@@ -20,6 +21,15 @@ class View(Module):
         return x.view(*self.shape)
 
 
+class GetArg(Module):
+    def __init__(self, index: int) -> None:
+        super().__init__()
+        self.index = index
+
+    def forward(self, *args: Tensor) -> Tensor:
+        return args[self.index]
+
+
 class Flatten(Module):
     def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
         super().__init__()
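GetArg selects one tensor out of the positional arguments it receives, which is useful inside chains whose upstream layers fan out multiple outputs. A small illustrative sketch:

import torch
from refiners.fluxion.layers import GetArg

a, b = torch.randn(3), torch.randn(3)
# GetArg(1) simply returns its second positional argument.
picker = GetArg(1)
assert torch.equal(picker(a, b), b)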
@@ -101,6 +111,45 @@ class Unsqueeze(Module):
         return x.unsqueeze(self.dim)
 
 
+class Unbind(Module):
+    def __init__(self, dim: int = 0) -> None:
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
+        return x.unbind(dim=self.dim)  # type: ignore
+
+
+class Chunk(Module):
+    def __init__(self, chunks: int, dim: int = 0) -> None:
+        super().__init__()
+        self.chunks = chunks
+        self.dim = dim
+
+    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
+        return x.chunk(chunks=self.chunks, dim=self.dim)  # type: ignore
+
+
+class Sin(Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.sin(input=x)
+
+
+class Cos(Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.cos(input=x)
+
+
+class Multiply(Module):
+    def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None:
+        super().__init__()
+        self.scale = scale
+        self.bias = bias
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.scale * x + self.bias
+
+
 class Parameter(WeightedModule):
     """
     A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
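A short sketch of the new basics in action; shapes are illustrative. Note that Unbind and Chunk return tuples of tensors rather than a single tensor, which is why their forward signatures differ from the elementwise layers:

import torch
from refiners.fluxion.layers import Unbind, Chunk, Sin, Multiply

x = torch.randn(4, 6)

# Unbind removes a dimension and returns one tensor per slice.
rows = Unbind(dim=0)(x)  # tuple of 4 tensors, each of shape (6,)
assert len(rows) == 4

# Chunk splits a dimension into the requested number of pieces.
halves = Chunk(chunks=2, dim=1)(x)  # tuple of 2 tensors, each of shape (4, 3)
assert halves[0].shape == (4, 3)

# Sin/Cos/Multiply are elementwise; Multiply computes scale * x + bias.
y = Sin()(Multiply(scale=0.5, bias=1.0)(x))  # sin(0.5 * x + 1.0)
assert y.shape == x.shape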

src/refiners/fluxion/layers/chain.py

@@ -1,5 +1,6 @@
 import inspect
 from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
+import torch
 from torch import Tensor, cat, device as Device, dtype as DType
 from refiners.fluxion.layers.basics import Identity
 from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
@@ -483,3 +484,16 @@ class Concatenate(Chain):
 
     def _show_only_tag(self) -> bool:
         return self.__class__ == Concatenate
+
+
+class Matmul(Chain):
+    _tag = "MATMUL"
+
+    def __init__(self, input: Module, other: Module) -> None:
+        super().__init__(
+            input,
+            other,
+        )
+
+    def forward(self, *args: Tensor) -> Tensor:
+        return torch.matmul(input=self[0](*args), other=self[1](*args))
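Matmul is a two-branch Chain: both sub-modules receive the same inputs, and their outputs are matrix-multiplied. A sketch of how it could express an attention-style score computation; the Identity/Transpose wiring is illustrative, not taken from this commit:

import torch
from refiners.fluxion.layers import Matmul, Identity, Transpose

# Both branches see the same input tensor; their outputs are multiplied.
scores = Matmul(
    input=Identity(),                 # queries, shape (batch, seq, dim)
    other=Transpose(dim0=1, dim1=2),  # keys transposed to (batch, dim, seq)
)

q = torch.randn(2, 5, 8)
out = scores(q)  # q @ q.transpose(1, 2) -> shape (2, 5, 5)
assert out.shape == (2, 5, 5)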