use Converter layer for sinusoidal embedding

Cédric Deltheil 2023-08-21 14:34:40 +02:00 committed by Cédric Deltheil
parent 108fa8f26a
commit b91a457495
2 changed files with 4 additions and 2 deletions

@@ -34,13 +34,14 @@ class RangeEncoder(fl.Chain):
         self.embedding_dim = embedding_dim
         super().__init__(
             fl.Lambda(self.compute_sinuosoidal_embedding),
+            fl.Converter(set_device=False, set_dtype=True),
             fl.Linear(in_features=sinuosidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
             fl.SiLU(),
             fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
         )
 
     def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]:
-        return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim).to(self.dtype)
+        return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim)
 
 
 class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):

@@ -25,6 +25,7 @@ class TextTimeEmbedding(fl.Chain):
                 ),
                 dim=1,
             ),
+            fl.Converter(set_device=False, set_dtype=True),
             fl.Linear(
                 in_features=self.text_time_embedding_dim,
                 out_features=self.timestep_embedding_dim,
@@ -41,7 +42,7 @@ class TextTimeEmbedding(fl.Chain):
         )
 
     def compute_sinuosoidal_embedding(self, x: Tensor) -> Tensor:
-        return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim).to(dtype=self.dtype)
+        return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim)
 
 
 class TimestepEncoder(fl.Passthrough):
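
To make the intent of the change easier to follow, here is a minimal, hypothetical sketch (not the actual refiners source: the class name, dimensions, and stand-in embedding function are made up). The sinusoidal embedding is typically produced in full precision, so before this commit each lambda cast its own output with `.to(self.dtype)`. Placing an `fl.Converter(set_device=False, set_dtype=True)` right after the `fl.Lambda` moves that cast into the Chain itself, assuming, as the diff suggests, that the Converter layer casts the tensors flowing through it to its parent Chain's dtype.

# Hypothetical minimal sketch; `fl` is refiners.fluxion.layers, everything else is made up.
import torch

import refiners.fluxion.layers as fl


class ToyEncoder(fl.Chain):
    """Mirrors the RangeEncoder structure after this commit (toy embedding, toy sizes)."""

    def __init__(self, embedding_dim: int = 32) -> None:
        self.embedding_dim = embedding_dim
        super().__init__(
            # The lambda always yields a full-precision tensor, whatever the Chain's dtype is.
            fl.Lambda(self.compute_embedding),
            # Casts the lambda's output to the Chain's dtype, replacing the explicit
            # `.to(self.dtype)` that used to live inside the lambda.
            fl.Converter(set_device=False, set_dtype=True),
            fl.Linear(in_features=embedding_dim, out_features=embedding_dim),
        )

    def compute_embedding(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for compute_sinusoidal_embedding: returns float32 unconditionally.
        return torch.randn(*x.shape, self.embedding_dim)


encoder = ToyEncoder().to(dtype=torch.float16)
output = encoder(torch.zeros(2, 1, dtype=torch.int64))
assert output.dtype == torch.float16  # the Converter, not the lambda, handled the cast

This keeps dtype handling declarative: the whole Chain can be switched to half precision with a single `.to(...)` call, and the embedding lambdas stay dtype-agnostic.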