mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
remove unused class CrossAttention in SAM
This commit is contained in:
parent
a1ad317b00
commit
86867e9318
|
@ -3,29 +3,6 @@ from torch import device as Device, dtype as DType
|
|||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
class CrossAttention(fl.Attention):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
cross_embedding_dim: int | None = None,
|
||||
num_heads: int = 1,
|
||||
inner_dim: int | None = None,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embedding_dim=embedding_dim,
|
||||
key_embedding_dim=cross_embedding_dim,
|
||||
num_heads=num_heads,
|
||||
inner_dim=inner_dim,
|
||||
is_optimized=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.cross_embedding_dim = cross_embedding_dim or embedding_dim
|
||||
self.insert(index=0, module=fl.Parallel(fl.GetArg(index=0), fl.GetArg(index=1), fl.GetArg(index=1)))
|
||||
|
||||
|
||||
class FeedForward(fl.Residual):
|
||||
def __init__(
|
||||
self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None
|
||||
|
|
Loading…
Reference in a new issue