diff --git a/src/refiners/foundationals/swin/swin_transformer.py b/src/refiners/foundationals/swin/swin_transformer.py index dd1fca4..488819e 100644 --- a/src/refiners/foundationals/swin/swin_transformer.py +++ b/src/refiners/foundationals/swin/swin_transformer.py @@ -173,7 +173,6 @@ class RelativePositionBias(fl.Module): class WindowSDPA(fl.Module): def __init__( self, - dim: int, window_size: int, num_heads: int, shift: bool = False, @@ -223,7 +222,7 @@ class WindowAttention(fl.Chain): ) -> None: super().__init__( fl.Linear(dim, dim * 3, bias=True, device=device), - WindowSDPA(dim, window_size, num_heads, shift, device=device), + WindowSDPA(window_size, num_heads, shift, device=device), fl.Linear(dim, dim, device=device), )