mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
add new basic layers and Matmul chain
This commit is contained in:
parent
2f2510a9b1
commit
b515c02867
|
@ -7,11 +7,17 @@ from refiners.fluxion.layers.basics import (
|
|||
Flatten,
|
||||
Unflatten,
|
||||
Transpose,
|
||||
GetArg,
|
||||
Permute,
|
||||
Reshape,
|
||||
Squeeze,
|
||||
Unsqueeze,
|
||||
Slicing,
|
||||
Sin,
|
||||
Cos,
|
||||
Chunk,
|
||||
Multiply,
|
||||
Unbind,
|
||||
Parameter,
|
||||
Buffer,
|
||||
)
|
||||
|
@ -28,6 +34,7 @@ from refiners.fluxion.layers.chain import (
|
|||
Passthrough,
|
||||
Breakpoint,
|
||||
Concatenate,
|
||||
Matmul,
|
||||
)
|
||||
from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
|
||||
from refiners.fluxion.layers.linear import Linear, MultiLinear
|
||||
|
@ -53,6 +60,7 @@ __all__ = [
|
|||
"SelfAttention",
|
||||
"SelfAttention2d",
|
||||
"Identity",
|
||||
"GetArg",
|
||||
"View",
|
||||
"Flatten",
|
||||
"Unflatten",
|
||||
|
@ -63,6 +71,12 @@ __all__ = [
|
|||
"Reshape",
|
||||
"Slicing",
|
||||
"Parameter",
|
||||
"Sin",
|
||||
"Cos",
|
||||
"Chunk",
|
||||
"Multiply",
|
||||
"Unbind",
|
||||
"Matmul",
|
||||
"Buffer",
|
||||
"Lambda",
|
||||
"Return",
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from refiners.fluxion.layers.module import Module, WeightedModule
|
||||
import torch
|
||||
from torch import randn, Tensor, Size, device as Device, dtype as DType
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
|
||||
|
@ -20,6 +21,15 @@ class View(Module):
|
|||
return x.view(*self.shape)
|
||||
|
||||
|
||||
class GetArg(Module):
|
||||
def __init__(self, index: int) -> None:
|
||||
super().__init__()
|
||||
self.index = index
|
||||
|
||||
def forward(self, *args: Tensor) -> Tensor:
|
||||
return args[self.index]
|
||||
|
||||
|
||||
class Flatten(Module):
|
||||
def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
|
||||
super().__init__()
|
||||
|
@ -101,6 +111,45 @@ class Unsqueeze(Module):
|
|||
return x.unsqueeze(self.dim)
|
||||
|
||||
|
||||
class Unbind(Module):
|
||||
def __init__(self, dim: int = 0) -> None:
|
||||
self.dim = dim
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: Tensor) -> tuple[Tensor, ...]:
|
||||
return x.unbind(dim=self.dim) # type: ignore
|
||||
|
||||
|
||||
class Chunk(Module):
|
||||
def __init__(self, chunks: int, dim: int = 0) -> None:
|
||||
self.chunks = chunks
|
||||
self.dim = dim
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: Tensor) -> tuple[Tensor, ...]:
|
||||
return x.chunk(chunks=self.chunks, dim=self.dim) # type: ignore
|
||||
|
||||
|
||||
class Sin(Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return torch.sin(input=x)
|
||||
|
||||
|
||||
class Cos(Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return torch.cos(input=x)
|
||||
|
||||
|
||||
class Multiply(Module):
|
||||
def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None:
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.bias = bias
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.scale * x + self.bias
|
||||
|
||||
|
||||
class Parameter(WeightedModule):
|
||||
"""
|
||||
A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import inspect
|
||||
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
|
||||
import torch
|
||||
from torch import Tensor, cat, device as Device, dtype as DType
|
||||
from refiners.fluxion.layers.basics import Identity
|
||||
from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
|
||||
|
@ -483,3 +484,16 @@ class Concatenate(Chain):
|
|||
|
||||
def _show_only_tag(self) -> bool:
|
||||
return self.__class__ == Concatenate
|
||||
|
||||
|
||||
class Matmul(Chain):
|
||||
_tag = "MATMUL"
|
||||
|
||||
def __init__(self, input: Module, other: Module) -> None:
|
||||
super().__init__(
|
||||
input,
|
||||
other,
|
||||
)
|
||||
|
||||
def forward(self, *args: Tensor) -> Tensor:
|
||||
return torch.matmul(input=self[0](*args), other=self[1](*args))
|
||||
|
|
Loading…
Reference in a new issue