mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
remove unused class CrossAttention in SAM
This commit is contained in:
parent
a1ad317b00
commit
86867e9318
|
@ -3,29 +3,6 @@ from torch import device as Device, dtype as DType
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(fl.Attention):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
embedding_dim: int,
|
|
||||||
cross_embedding_dim: int | None = None,
|
|
||||||
num_heads: int = 1,
|
|
||||||
inner_dim: int | None = None,
|
|
||||||
device: Device | str | None = None,
|
|
||||||
dtype: DType | None = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
embedding_dim=embedding_dim,
|
|
||||||
key_embedding_dim=cross_embedding_dim,
|
|
||||||
num_heads=num_heads,
|
|
||||||
inner_dim=inner_dim,
|
|
||||||
is_optimized=False,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
self.cross_embedding_dim = cross_embedding_dim or embedding_dim
|
|
||||||
self.insert(index=0, module=fl.Parallel(fl.GetArg(index=0), fl.GetArg(index=1), fl.GetArg(index=1)))
|
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(fl.Residual):
|
class FeedForward(fl.Residual):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None
|
self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None
|
||||||
|
|
Loading…
Reference in a new issue