diff --git a/src/refiners/foundationals/segment_anything/transformer.py b/src/refiners/foundationals/segment_anything/transformer.py
index 7abe2f5..855710c 100644
--- a/src/refiners/foundationals/segment_anything/transformer.py
+++ b/src/refiners/foundationals/segment_anything/transformer.py
@@ -3,29 +3,6 @@ from torch import device as Device, dtype as DType
 import refiners.fluxion.layers as fl
 
 
-class CrossAttention(fl.Attention):
-    def __init__(
-        self,
-        embedding_dim: int,
-        cross_embedding_dim: int | None = None,
-        num_heads: int = 1,
-        inner_dim: int | None = None,
-        device: Device | str | None = None,
-        dtype: DType | None = None,
-    ) -> None:
-        super().__init__(
-            embedding_dim=embedding_dim,
-            key_embedding_dim=cross_embedding_dim,
-            num_heads=num_heads,
-            inner_dim=inner_dim,
-            is_optimized=False,
-            device=device,
-            dtype=dtype,
-        )
-        self.cross_embedding_dim = cross_embedding_dim or embedding_dim
-        self.insert(index=0, module=fl.Parallel(fl.GetArg(index=0), fl.GetArg(index=1), fl.GetArg(index=1)))
-
-
 class FeedForward(fl.Residual):
     def __init__(
         self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None