(fluxion/layers) remove View layer

+ replace existing `View` layers by `Reshape`
This commit is contained in:
Laurent 2024-02-02 16:40:40 +00:00 committed by Laureηt
parent 2ef4982e04
commit 9883f24f9a
3 changed files with 1 additions and 12 deletions

View file

@ -28,7 +28,6 @@ from refiners.fluxion.layers.basics import (
Transpose,
Unflatten,
Unsqueeze,
View,
)
from refiners.fluxion.layers.chain import (
Breakpoint,
@ -75,7 +74,6 @@ __all__ = [
"SelfAttention2d",
"Identity",
"GetArg",
"View",
"Flatten",
"Unflatten",
"Transpose",

View file

@ -28,15 +28,6 @@ class Identity(Module):
return x
class View(Module):
def __init__(self, *shape: int) -> None:
super().__init__()
self.shape = shape
def forward(self, x: Tensor) -> Tensor:
return x.view(*self.shape)
class GetArg(Module):
"""GetArg operation layer.

View file

@ -63,6 +63,6 @@ class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
fl.UseContext("range_adapter", context_key),
fl.SiLU(),
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
fl.View(-1, channels, 1, 1),
fl.Reshape(channels, 1, 1),
),
)