remove unused class CrossAttention in SAM

Pierre Chapuis 2024-01-29 16:34:35 +01:00 committed by Cédric Deltheil
parent a1ad317b00
commit 86867e9318

@@ -3,29 +3,6 @@ from torch import device as Device, dtype as DType
 import refiners.fluxion.layers as fl
 
 
-class CrossAttention(fl.Attention):
-    def __init__(
-        self,
-        embedding_dim: int,
-        cross_embedding_dim: int | None = None,
-        num_heads: int = 1,
-        inner_dim: int | None = None,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ) -> None:
-        super().__init__(
-            embedding_dim=embedding_dim,
-            key_embedding_dim=cross_embedding_dim,
-            num_heads=num_heads,
-            inner_dim=inner_dim,
-            is_optimized=False,
-            device=device,
-            dtype=dtype,
-        )
-        self.cross_embedding_dim = cross_embedding_dim or embedding_dim
-        self.insert(index=0, module=fl.Parallel(fl.GetArg(index=0), fl.GetArg(index=1), fl.GetArg(index=1)))
-
-
 class FeedForward(fl.Residual):
     def __init__(
         self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None
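
For context, a minimal sketch of what the removed class did: the fl.Parallel(fl.GetArg(0), fl.GetArg(1), fl.GetArg(1)) module inserted at index 0 fanned a two-argument call out into (query, key, value) inputs for fl.Attention. The snippet below assumes a refiners checkout predating this commit (when CrossAttention was still defined); the dimensions and tensor shapes are illustrative assumptions, not values from the codebase.

import torch

# Assumes the CrossAttention class shown in the diff above is in scope.
attn = CrossAttention(embedding_dim=64, cross_embedding_dim=128, num_heads=4)

x = torch.randn(1, 16, 64)        # queries:     (batch, seq_q, embedding_dim)
context = torch.randn(1, 9, 128)  # keys/values: (batch, seq_kv, cross_embedding_dim)

# The Parallel(GetArg(0), GetArg(1), GetArg(1)) prefix turns the call
# attn(x, context) into Attention(query=x, key=context, value=context).
out = attn(x, context)
assert out.shape == (1, 16, 64)   # output is projected back to embedding_dim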