remove unused argument in WindowSDPA
Some checks failed
CI / lint_and_typecheck (push) Has been cancelled
Deploy docs to GitHub Pages / Deploy docs (push) Has been cancelled
Spell checker / Spell check (push) Has been cancelled

This commit is contained in:
Pierre Chapuis 2024-09-12 15:22:40 +02:00
parent 31b5f80496
commit 336253f26b

View file

@ -173,7 +173,6 @@ class RelativePositionBias(fl.Module):
class WindowSDPA(fl.Module): class WindowSDPA(fl.Module):
def __init__( def __init__(
self, self,
dim: int,
window_size: int, window_size: int,
num_heads: int, num_heads: int,
shift: bool = False, shift: bool = False,
@ -223,7 +222,7 @@ class WindowAttention(fl.Chain):
) -> None: ) -> None:
super().__init__( super().__init__(
fl.Linear(dim, dim * 3, bias=True, device=device), fl.Linear(dim, dim * 3, bias=True, device=device),
WindowSDPA(dim, window_size, num_heads, shift, device=device), WindowSDPA(window_size, num_heads, shift, device=device),
fl.Linear(dim, dim, device=device), fl.Linear(dim, dim, device=device),
) )