mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)
add new basic layers and Matmul chain
commit b515c02867
parent 2f2510a9b1
refiners/fluxion/layers/__init__.py

@@ -7,11 +7,17 @@ from refiners.fluxion.layers.basics import (
     Flatten,
     Unflatten,
     Transpose,
+    GetArg,
     Permute,
     Reshape,
     Squeeze,
     Unsqueeze,
     Slicing,
+    Sin,
+    Cos,
+    Chunk,
+    Multiply,
+    Unbind,
     Parameter,
     Buffer,
 )
@@ -28,6 +34,7 @@ from refiners.fluxion.layers.chain import (
     Passthrough,
     Breakpoint,
     Concatenate,
+    Matmul,
 )
 from refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
 from refiners.fluxion.layers.linear import Linear, MultiLinear
@@ -53,6 +60,7 @@ __all__ = [
     "SelfAttention",
     "SelfAttention2d",
     "Identity",
+    "GetArg",
     "View",
     "Flatten",
     "Unflatten",
@@ -63,6 +71,12 @@ __all__ = [
     "Reshape",
     "Slicing",
     "Parameter",
+    "Sin",
+    "Cos",
+    "Chunk",
+    "Multiply",
+    "Unbind",
+    "Matmul",
     "Buffer",
     "Lambda",
     "Return",
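With the export changes above, the new layers become importable from the refiners.fluxion.layers namespace alongside the existing ones. A minimal import sketch (illustrative, not part of the diff):

# The names added to __all__ above are now part of the public layers API.
from refiners.fluxion.layers import (
    GetArg,
    Sin,
    Cos,
    Chunk,
    Multiply,
    Unbind,
    Matmul,
)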
refiners/fluxion/layers/basics.py

@@ -1,4 +1,5 @@
 from refiners.fluxion.layers.module import Module, WeightedModule
+import torch
 from torch import randn, Tensor, Size, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter

@@ -20,6 +21,15 @@ class View(Module):
         return x.view(*self.shape)


+class GetArg(Module):
+    def __init__(self, index: int) -> None:
+        super().__init__()
+        self.index = index
+
+    def forward(self, *args: Tensor) -> Tensor:
+        return args[self.index]
+
+
 class Flatten(Module):
     def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
         super().__init__()
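A minimal usage sketch for GetArg (illustrative tensors, not part of the commit): it simply returns one of its positional arguments, which is handy inside chains that pass several tensors along.

import torch
from refiners.fluxion.layers import GetArg

# GetArg(index=1) forwards the second positional tensor it receives, unchanged.
pick_second = GetArg(index=1)
a, b = torch.zeros(2, 3), torch.ones(2, 3)
assert torch.equal(pick_second(a, b), b)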
@@ -101,6 +111,45 @@ class Unsqueeze(Module):
         return x.unsqueeze(self.dim)


+class Unbind(Module):
+    def __init__(self, dim: int = 0) -> None:
+        self.dim = dim
+        super().__init__()
+
+    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
+        return x.unbind(dim=self.dim)  # type: ignore
+
+
+class Chunk(Module):
+    def __init__(self, chunks: int, dim: int = 0) -> None:
+        self.chunks = chunks
+        self.dim = dim
+        super().__init__()
+
+    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
+        return x.chunk(chunks=self.chunks, dim=self.dim)  # type: ignore
+
+
+class Sin(Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.sin(input=x)
+
+
+class Cos(Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.cos(input=x)
+
+
+class Multiply(Module):
+    def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None:
+        super().__init__()
+        self.scale = scale
+        self.bias = bias
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.scale * x + self.bias
+
+
 class Parameter(WeightedModule):
     """
     A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
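A short sketch of the other new basic layers added in this hunk (values are illustrative, not part of the commit):

import torch
from refiners.fluxion.layers import Chunk, Cos, Multiply, Sin, Unbind

x = torch.linspace(0.0, 3.14, steps=8)

# Sin and Cos apply torch.sin / torch.cos elementwise.
y = Sin()(x)
z = Cos()(x)

# Multiply computes scale * x + bias, i.e. an affine rescaling.
scaled = Multiply(scale=2.0, bias=0.5)(x)

# Chunk splits a tensor into a tuple of chunks along a dimension;
# Unbind removes a dimension and returns the slices as a tuple.
halves = Chunk(chunks=2, dim=0)(x)         # tuple of two tensors of shape (4,)
rows = Unbind(dim=0)(torch.zeros(3, 5))    # tuple of three tensors of shape (5,)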
refiners/fluxion/layers/chain.py

@@ -1,5 +1,6 @@
 import inspect
 from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
+import torch
 from torch import Tensor, cat, device as Device, dtype as DType
 from refiners.fluxion.layers.basics import Identity
 from refiners.fluxion.layers.module import Module, ContextModule, WeightedModule
@@ -483,3 +484,16 @@ class Concatenate(Chain):

     def _show_only_tag(self) -> bool:
         return self.__class__ == Concatenate
+
+
+class Matmul(Chain):
+    _tag = "MATMUL"
+
+    def __init__(self, input: Module, other: Module) -> None:
+        super().__init__(
+            input,
+            other,
+        )
+
+    def forward(self, *args: Tensor) -> Tensor:
+        return torch.matmul(input=self[0](*args), other=self[1](*args))
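A sketch of how the new Matmul chain composes: both sub-modules receive the same arguments and their outputs are matrix-multiplied (first module's output on the left). The Identity/Transpose pairing below is only an illustration, and the Transpose signature is assumed to mirror torch.transpose; it is not part of the commit.

import torch
import refiners.fluxion.layers as fl

# Computes a batched Gram matrix: x @ x.transpose(1, 2).
gram = fl.Matmul(
    fl.Identity(),       # left operand: x itself
    fl.Transpose(1, 2),  # right operand: x with dims 1 and 2 swapped
)
x = torch.randn(2, 4, 16)
g = gram(x)              # shape (2, 4, 4)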