fix typo (sinuosoidal -> sinusoidal)

This commit is contained in:
Pierre Chapuis 2024-01-19 14:33:56 +01:00
parent d34e36797b
commit de6266010d
2 changed files with 4 additions and 4 deletions

View file

@ -32,14 +32,14 @@ class RangeEncoder(fl.Chain):
self.sinusoidal_embedding_dim = sinusoidal_embedding_dim self.sinusoidal_embedding_dim = sinusoidal_embedding_dim
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
super().__init__( super().__init__(
fl.Lambda(self.compute_sinuosoidal_embedding), fl.Lambda(self.compute_sinusoidal_embedding),
fl.Converter(set_device=False, set_dtype=True), fl.Converter(set_device=False, set_dtype=True),
fl.Linear(in_features=sinusoidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), fl.Linear(in_features=sinusoidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
fl.SiLU(), fl.SiLU(),
fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype), fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
) )
def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]: def compute_sinusoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]:
return compute_sinusoidal_embedding(x, embedding_dim=self.sinusoidal_embedding_dim) return compute_sinusoidal_embedding(x, embedding_dim=self.sinusoidal_embedding_dim)

View file

@ -28,7 +28,7 @@ class TextTimeEmbedding(fl.Chain):
fl.Chain( fl.Chain(
fl.UseContext(context="diffusion", key="time_ids"), fl.UseContext(context="diffusion", key="time_ids"),
fl.Unsqueeze(dim=-1), fl.Unsqueeze(dim=-1),
fl.Lambda(func=self.compute_sinuosoidal_embedding), fl.Lambda(func=self.compute_sinusoidal_embedding),
fl.Reshape(-1), fl.Reshape(-1),
), ),
dim=1, dim=1,
@ -49,7 +49,7 @@ class TextTimeEmbedding(fl.Chain):
), ),
) )
def compute_sinuosoidal_embedding(self, x: Tensor) -> Tensor: def compute_sinusoidal_embedding(self, x: Tensor) -> Tensor:
return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim) return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim)