mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
(fluxion/layers) remove View
layer
+ replace existing `View` layers by `Reshape`
This commit is contained in:
parent
2ef4982e04
commit
9883f24f9a
|
@ -28,7 +28,6 @@ from refiners.fluxion.layers.basics import (
|
|||
Transpose,
|
||||
Unflatten,
|
||||
Unsqueeze,
|
||||
View,
|
||||
)
|
||||
from refiners.fluxion.layers.chain import (
|
||||
Breakpoint,
|
||||
|
@ -75,7 +74,6 @@ __all__ = [
|
|||
"SelfAttention2d",
|
||||
"Identity",
|
||||
"GetArg",
|
||||
"View",
|
||||
"Flatten",
|
||||
"Unflatten",
|
||||
"Transpose",
|
||||
|
|
|
@ -28,15 +28,6 @@ class Identity(Module):
|
|||
return x
|
||||
|
||||
|
||||
class View(Module):
|
||||
def __init__(self, *shape: int) -> None:
|
||||
super().__init__()
|
||||
self.shape = shape
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x.view(*self.shape)
|
||||
|
||||
|
||||
class GetArg(Module):
|
||||
"""GetArg operation layer.
|
||||
|
||||
|
|
|
@ -63,6 +63,6 @@ class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
|
|||
fl.UseContext("range_adapter", context_key),
|
||||
fl.SiLU(),
|
||||
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
|
||||
fl.View(-1, channels, 1, 1),
|
||||
fl.Reshape(channels, 1, 1),
|
||||
),
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue