(fluxion/layers) remove View layer

+ replace existing `View` layers by `Reshape`
This commit is contained in:
Laurent 2024-02-02 16:40:40 +00:00 committed by Laureηt
parent 2ef4982e04
commit 9883f24f9a
3 changed files with 1 additions and 12 deletions

View file

@ -28,7 +28,6 @@ from refiners.fluxion.layers.basics import (
Transpose, Transpose,
Unflatten, Unflatten,
Unsqueeze, Unsqueeze,
View,
) )
from refiners.fluxion.layers.chain import ( from refiners.fluxion.layers.chain import (
Breakpoint, Breakpoint,
@ -75,7 +74,6 @@ __all__ = [
"SelfAttention2d", "SelfAttention2d",
"Identity", "Identity",
"GetArg", "GetArg",
"View",
"Flatten", "Flatten",
"Unflatten", "Unflatten",
"Transpose", "Transpose",

View file

@ -28,15 +28,6 @@ class Identity(Module):
return x return x
class View(Module):
def __init__(self, *shape: int) -> None:
super().__init__()
self.shape = shape
def forward(self, x: Tensor) -> Tensor:
return x.view(*self.shape)
class GetArg(Module): class GetArg(Module):
"""GetArg operation layer. """GetArg operation layer.

View file

@ -63,6 +63,6 @@ class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
fl.UseContext("range_adapter", context_key), fl.UseContext("range_adapter", context_key),
fl.SiLU(), fl.SiLU(),
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype), fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
fl.View(-1, channels, 1, 1), fl.Reshape(channels, 1, 1),
), ),
) )