mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
fix typo (sinuosidal -> sinusoidal)
This commit is contained in:
parent
fde61757fb
commit
d34e36797b
|
@ -24,23 +24,23 @@ def compute_sinusoidal_embedding(
|
||||||
class RangeEncoder(fl.Chain):
|
class RangeEncoder(fl.Chain):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sinuosidal_embedding_dim: int,
|
sinusoidal_embedding_dim: int,
|
||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.sinuosidal_embedding_dim = sinuosidal_embedding_dim
|
self.sinusoidal_embedding_dim = sinusoidal_embedding_dim
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Lambda(self.compute_sinuosoidal_embedding),
|
fl.Lambda(self.compute_sinuosoidal_embedding),
|
||||||
fl.Converter(set_device=False, set_dtype=True),
|
fl.Converter(set_device=False, set_dtype=True),
|
||||||
fl.Linear(in_features=sinuosidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
fl.Linear(in_features=sinusoidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
||||||
fl.SiLU(),
|
fl.SiLU(),
|
||||||
fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]:
|
def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]:
|
||||||
return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim)
|
return compute_sinusoidal_embedding(x, embedding_dim=self.sinusoidal_embedding_dim)
|
||||||
|
|
||||||
|
|
||||||
class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
|
class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
|
||||||
|
|
|
@ -61,7 +61,7 @@ class TimestepEncoder(fl.Passthrough):
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.UseContext(context="diffusion", key="timestep"),
|
fl.UseContext(context="diffusion", key="timestep"),
|
||||||
RangeEncoder(
|
RangeEncoder(
|
||||||
sinuosidal_embedding_dim=320,
|
sinusoidal_embedding_dim=320,
|
||||||
embedding_dim=self.timestep_embedding_dim,
|
embedding_dim=self.timestep_embedding_dim,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
Loading…
Reference in a new issue