mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
remove unused Chunk and Unbind layers
This commit is contained in:
parent
c57f2228f8
commit
e6be1394ff
|
@ -2,7 +2,6 @@ from refiners.fluxion.layers.activations import GLU, ApproximateGeLU, GeLU, ReLU
|
||||||
from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
|
from refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
|
||||||
from refiners.fluxion.layers.basics import (
|
from refiners.fluxion.layers.basics import (
|
||||||
Buffer,
|
Buffer,
|
||||||
Chunk,
|
|
||||||
Cos,
|
Cos,
|
||||||
Flatten,
|
Flatten,
|
||||||
GetArg,
|
GetArg,
|
||||||
|
@ -15,7 +14,6 @@ from refiners.fluxion.layers.basics import (
|
||||||
Slicing,
|
Slicing,
|
||||||
Squeeze,
|
Squeeze,
|
||||||
Transpose,
|
Transpose,
|
||||||
Unbind,
|
|
||||||
Unflatten,
|
Unflatten,
|
||||||
Unsqueeze,
|
Unsqueeze,
|
||||||
View,
|
View,
|
||||||
|
@ -75,9 +73,7 @@ __all__ = [
|
||||||
"Parameter",
|
"Parameter",
|
||||||
"Sin",
|
"Sin",
|
||||||
"Cos",
|
"Cos",
|
||||||
"Chunk",
|
|
||||||
"Multiply",
|
"Multiply",
|
||||||
"Unbind",
|
|
||||||
"Matmul",
|
"Matmul",
|
||||||
"Buffer",
|
"Buffer",
|
||||||
"Lambda",
|
"Lambda",
|
||||||
|
|
|
@ -130,25 +130,6 @@ class Unsqueeze(Module):
|
||||||
return x.unsqueeze(self.dim)
|
return x.unsqueeze(self.dim)
|
||||||
|
|
||||||
|
|
||||||
class Unbind(Module):
|
|
||||||
def __init__(self, dim: int = 0) -> None:
|
|
||||||
self.dim = dim
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> tuple[Tensor, ...]:
|
|
||||||
return x.unbind(dim=self.dim) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class Chunk(Module):
|
|
||||||
def __init__(self, chunks: int, dim: int = 0) -> None:
|
|
||||||
self.chunks = chunks
|
|
||||||
self.dim = dim
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> tuple[Tensor, ...]:
|
|
||||||
return x.chunk(chunks=self.chunks, dim=self.dim) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class Sin(Module):
|
class Sin(Module):
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return torch.sin(input=x)
|
return torch.sin(input=x)
|
||||||
|
|
Loading…
Reference in a new issue