mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
(fluxion/layers) remove View
layer
+ replace existing `View` layers by `Reshape`
This commit is contained in:
parent
2ef4982e04
commit
9883f24f9a
|
@ -28,7 +28,6 @@ from refiners.fluxion.layers.basics import (
|
||||||
Transpose,
|
Transpose,
|
||||||
Unflatten,
|
Unflatten,
|
||||||
Unsqueeze,
|
Unsqueeze,
|
||||||
View,
|
|
||||||
)
|
)
|
||||||
from refiners.fluxion.layers.chain import (
|
from refiners.fluxion.layers.chain import (
|
||||||
Breakpoint,
|
Breakpoint,
|
||||||
|
@ -75,7 +74,6 @@ __all__ = [
|
||||||
"SelfAttention2d",
|
"SelfAttention2d",
|
||||||
"Identity",
|
"Identity",
|
||||||
"GetArg",
|
"GetArg",
|
||||||
"View",
|
|
||||||
"Flatten",
|
"Flatten",
|
||||||
"Unflatten",
|
"Unflatten",
|
||||||
"Transpose",
|
"Transpose",
|
||||||
|
|
|
@ -28,15 +28,6 @@ class Identity(Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class View(Module):
|
|
||||||
def __init__(self, *shape: int) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.shape = shape
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return x.view(*self.shape)
|
|
||||||
|
|
||||||
|
|
||||||
class GetArg(Module):
|
class GetArg(Module):
|
||||||
"""GetArg operation layer.
|
"""GetArg operation layer.
|
||||||
|
|
||||||
|
|
|
@ -63,6 +63,6 @@ class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
|
||||||
fl.UseContext("range_adapter", context_key),
|
fl.UseContext("range_adapter", context_key),
|
||||||
fl.SiLU(),
|
fl.SiLU(),
|
||||||
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
|
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
|
||||||
fl.View(-1, channels, 1, 1),
|
fl.Reshape(channels, 1, 1),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue