diff --git a/src/refiners/fluxion/layers/__init__.py b/src/refiners/fluxion/layers/__init__.py index 610f5d4..813307c 100644 --- a/src/refiners/fluxion/layers/__init__.py +++ b/src/refiners/fluxion/layers/__init__.py @@ -28,7 +28,6 @@ from refiners.fluxion.layers.basics import ( Transpose, Unflatten, Unsqueeze, - View, ) from refiners.fluxion.layers.chain import ( Breakpoint, @@ -75,7 +74,6 @@ __all__ = [ "SelfAttention2d", "Identity", "GetArg", - "View", "Flatten", "Unflatten", "Transpose", diff --git a/src/refiners/fluxion/layers/basics.py b/src/refiners/fluxion/layers/basics.py index 93ec638..6ae89f7 100644 --- a/src/refiners/fluxion/layers/basics.py +++ b/src/refiners/fluxion/layers/basics.py @@ -28,15 +28,6 @@ class Identity(Module): return x -class View(Module): - def __init__(self, *shape: int) -> None: - super().__init__() - self.shape = shape - - def forward(self, x: Tensor) -> Tensor: - return x.view(*self.shape) - - class GetArg(Module): """GetArg operation layer. diff --git a/src/refiners/foundationals/latent_diffusion/range_adapter.py b/src/refiners/foundationals/latent_diffusion/range_adapter.py index fc1da00..0717c75 100644 --- a/src/refiners/foundationals/latent_diffusion/range_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/range_adapter.py @@ -63,6 +63,6 @@ class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]): fl.UseContext("range_adapter", context_key), fl.SiLU(), fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype), - fl.View(-1, channels, 1, 1), + fl.Reshape(channels, 1, 1), ), )