diff --git a/src/refiners/foundationals/latent_diffusion/range_adapter.py b/src/refiners/foundationals/latent_diffusion/range_adapter.py index 317c13b..ce24232 100644 --- a/src/refiners/foundationals/latent_diffusion/range_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/range_adapter.py @@ -24,23 +24,23 @@ def compute_sinusoidal_embedding( class RangeEncoder(fl.Chain): def __init__( self, - sinuosidal_embedding_dim: int, + sinusoidal_embedding_dim: int, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None, ) -> None: - self.sinuosidal_embedding_dim = sinuosidal_embedding_dim + self.sinusoidal_embedding_dim = sinusoidal_embedding_dim self.embedding_dim = embedding_dim super().__init__( fl.Lambda(self.compute_sinuosoidal_embedding), fl.Converter(set_device=False, set_dtype=True), - fl.Linear(in_features=sinuosidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), + fl.Linear(in_features=sinusoidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), fl.SiLU(), fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), ) def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]: - return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim) + return compute_sinusoidal_embedding(x, embedding_dim=self.sinusoidal_embedding_dim) class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]): diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py index c2b49c8..efa4bc9 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py @@ -61,7 +61,7 @@ class TimestepEncoder(fl.Passthrough): fl.Chain( fl.UseContext(context="diffusion", key="timestep"), RangeEncoder( - sinuosidal_embedding_dim=320, + sinusoidal_embedding_dim=320, embedding_dim=self.timestep_embedding_dim, device=device, dtype=dtype,